diff --git a/src/frontend/A64/translate/impl/simd_permute.cpp b/src/frontend/A64/translate/impl/simd_permute.cpp index 1c936903..bf677156 100644 --- a/src/frontend/A64/translate/impl/simd_permute.cpp +++ b/src/frontend/A64/translate/impl/simd_permute.cpp @@ -13,8 +13,7 @@ enum class Transposition { TRN2, }; -bool VectorTranspose(TranslatorVisitor& v, bool Q, Imm<2> size, Vec Vm, Vec Vn, Vec Vd, - Transposition type) { +bool VectorTranspose(TranslatorVisitor& v, bool Q, Imm<2> size, Vec Vm, Vec Vn, Vec Vd, Transposition type) { if (!Q && size == 0b11) { return v.ReservedValue(); } @@ -24,44 +23,7 @@ bool VectorTranspose(TranslatorVisitor& v, bool Q, Imm<2> size, Vec Vm, Vec Vn, const IR::U128 m = v.V(datasize, Vm); const IR::U128 n = v.V(datasize, Vn); - - const IR::U128 result = [&] { - switch (esize) { - case 8: - case 16: - case 32: { - // Create a mask of elements we care about (e.g. for 8-bit: 0x00FF00FF00FF00FF for TRN1 - // and 0xFF00FF00FF00FF00 for TRN2) - const u64 mask_element = [&] { - const size_t shift = type == Transposition::TRN1 ? 0 : esize; - return Common::Ones(esize) << shift; - }(); - const size_t doubled_esize = esize * 2; - const u64 mask_value = Common::Replicate(mask_element, doubled_esize); - - const IR::U128 mask = v.ir.VectorBroadcast(64, v.I(64, mask_value)); - const IR::U128 anded_m = v.ir.VectorAnd(m, mask); - const IR::U128 anded_n = v.ir.VectorAnd(n, mask); - - if (type == Transposition::TRN1) { - return v.ir.VectorOr(v.ir.VectorLogicalShiftLeft(doubled_esize, anded_m, esize), anded_n); - } - - return v.ir.VectorOr(v.ir.VectorLogicalShiftRight(doubled_esize, anded_n, esize), anded_m); - } - case 64: { - default: - const auto [src, src_idx, dst, dst_idx] = [type, m, n] { - if (type == Transposition::TRN1) { - return std::make_tuple(m, 0, n, 1); - } - return std::make_tuple(n, 1, m, 0); - }(); - - return v.ir.VectorSetElement(esize, dst, dst_idx, v.ir.VectorGetElement(esize, src, src_idx)); - } - } - }(); + const IR::U128 result = v.ir.VectorTranspose(esize, n, m, type == Transposition::TRN2); v.V(datasize, Vd, result); return true;