diff --git a/src/dynarmic/backend/x64/emit_x64_sha.cpp b/src/dynarmic/backend/x64/emit_x64_sha.cpp index 2f1b5dd0..92e0841d 100644 --- a/src/dynarmic/backend/x64/emit_x64_sha.cpp +++ b/src/dynarmic/backend/x64/emit_x64_sha.cpp @@ -48,4 +48,34 @@ void EmitX64::EmitSHA256Hash(EmitContext& ctx, IR::Inst* inst) { ctx.reg_alloc.DefineValue(inst, y); } +void EmitX64::EmitSHA256MessageSchedule0(EmitContext& ctx, IR::Inst* inst) { + auto args = ctx.reg_alloc.GetArgumentInfo(inst); + + ASSERT(code.HasHostFeature(HostFeature::SHA)); + + const Xbyak::Xmm x = ctx.reg_alloc.UseScratchXmm(args[0]); + const Xbyak::Xmm y = ctx.reg_alloc.UseXmm(args[1]); + + code.sha256msg1(x, y); + + ctx.reg_alloc.DefineValue(inst, x); +} + +void EmitX64::EmitSHA256MessageSchedule1(EmitContext& ctx, IR::Inst* inst) { + auto args = ctx.reg_alloc.GetArgumentInfo(inst); + + ASSERT(code.HasHostFeature(HostFeature::SHA)); + + const Xbyak::Xmm x = ctx.reg_alloc.UseScratchXmm(args[0]); + const Xbyak::Xmm y = ctx.reg_alloc.UseXmm(args[1]); + const Xbyak::Xmm z = ctx.reg_alloc.UseXmm(args[2]); + + code.movaps(xmm0, z); + code.palignr(xmm0, y, 4); + code.paddd(x, xmm0); + code.sha256msg2(x, z); + + ctx.reg_alloc.DefineValue(inst, x); +} + } // namespace Dynarmic::Backend::X64 diff --git a/src/dynarmic/frontend/A64/translate/impl/simd_sha.cpp b/src/dynarmic/frontend/A64/translate/impl/simd_sha.cpp index 69e0dc24..7d586b3a 100644 --- a/src/dynarmic/frontend/A64/translate/impl/simd_sha.cpp +++ b/src/dynarmic/frontend/A64/translate/impl/simd_sha.cpp @@ -114,72 +114,21 @@ bool TranslatorVisitor::SHA1H(Vec Vn, Vec Vd) { } bool TranslatorVisitor::SHA256SU0(Vec Vn, Vec Vd) { - const IR::U128 d = ir.GetQ(Vd); - const IR::U128 n = ir.GetQ(Vn); + const IR::U128 x = ir.GetQ(Vd); + const IR::U128 y = ir.GetQ(Vn); - const IR::U128 t = [&] { - // Shuffle the upper three elements down: [3, 2, 1, 0] -> [0, 3, 2, 1] - const IR::U128 shuffled = ir.VectorShuffleWords(d, 0b00111001); - - return ir.VectorSetElement(32, shuffled, 3, ir.VectorGetElement(32, n, 0)); - }(); - - IR::U128 result = ir.ZeroVector(); - for (size_t i = 0; i < 4; i++) { - const IR::U32 modified_element = [&] { - const IR::U32 element = ir.VectorGetElement(32, t, i); - const IR::U32 tmp1 = ir.RotateRight(element, ir.Imm8(7)); - const IR::U32 tmp2 = ir.RotateRight(element, ir.Imm8(18)); - const IR::U32 tmp3 = ir.LogicalShiftRight(element, ir.Imm8(3)); - - return ir.Eor(tmp1, ir.Eor(tmp2, tmp3)); - }(); - - const IR::U32 d_element = ir.VectorGetElement(32, d, i); - result = ir.VectorSetElement(32, result, i, ir.Add(modified_element, d_element)); - } + const IR::U128 result = ir.SHA256MessageSchedule0(x, y); ir.SetQ(Vd, result); return true; } bool TranslatorVisitor::SHA256SU1(Vec Vm, Vec Vn, Vec Vd) { - const IR::U128 d = ir.GetQ(Vd); - const IR::U128 m = ir.GetQ(Vm); - const IR::U128 n = ir.GetQ(Vn); + const IR::U128 x = ir.GetQ(Vd); + const IR::U128 y = ir.GetQ(Vn); + const IR::U128 z = ir.GetQ(Vm); - const IR::U128 T0 = [&] { - const IR::U32 low_m = ir.VectorGetElement(32, m, 0); - const IR::U128 shuffled_n = ir.VectorShuffleWords(n, 0b00111001); - - return ir.VectorSetElement(32, shuffled_n, 3, low_m); - }(); - - const IR::U128 lower_half = [&] { - const IR::U128 T = ir.VectorShuffleWords(m, 0b01001110); - const IR::U128 tmp1 = ir.VectorRotateRight(32, T, 17); - const IR::U128 tmp2 = ir.VectorRotateRight(32, T, 19); - const IR::U128 tmp3 = ir.VectorLogicalShiftRight(32, T, 10); - const IR::U128 tmp4 = ir.VectorEor(tmp1, ir.VectorEor(tmp2, tmp3)); - const IR::U128 tmp5 = ir.VectorAdd(32, tmp4, ir.VectorAdd(32, d, T0)); - return ir.VectorZeroUpper(tmp5); - }(); - - const IR::U64 upper_half = [&] { - const IR::U128 tmp1 = ir.VectorRotateRight(32, lower_half, 17); - const IR::U128 tmp2 = ir.VectorRotateRight(32, lower_half, 19); - const IR::U128 tmp3 = ir.VectorLogicalShiftRight(32, lower_half, 10); - const IR::U128 tmp4 = ir.VectorEor(tmp1, ir.VectorEor(tmp2, tmp3)); - - // Shuffle the top two 32-bit elements downwards [3, 2, 1, 0] -> [1, 0, 3, 2] - const IR::U128 shuffled_d = ir.VectorShuffleWords(d, 0b01001110); - const IR::U128 shuffled_T0 = ir.VectorShuffleWords(T0, 0b01001110); - - const IR::U128 tmp5 = ir.VectorAdd(32, tmp4, ir.VectorAdd(32, shuffled_d, shuffled_T0)); - return ir.VectorGetElement(64, tmp5, 0); - }(); - - const IR::U128 result = ir.VectorSetElement(64, lower_half, 1, upper_half); + const IR::U128 result = ir.SHA256MessageSchedule1(x, y, z); ir.SetQ(Vd, result); return true; diff --git a/src/dynarmic/ir/ir_emitter.cpp b/src/dynarmic/ir/ir_emitter.cpp index dbe49bf7..dfb03ae9 100644 --- a/src/dynarmic/ir/ir_emitter.cpp +++ b/src/dynarmic/ir/ir_emitter.cpp @@ -907,6 +907,14 @@ U128 IREmitter::SHA256Hash(const U128& x, const U128& y, const U128& w, bool par return Inst(Opcode::SHA256Hash, x, y, w, Imm1(part1)); } +U128 IREmitter::SHA256MessageSchedule0(const U128& x, const U128& y) { + return Inst(Opcode::SHA256MessageSchedule0, x, y); +} + +U128 IREmitter::SHA256MessageSchedule1(const U128& x, const U128& y, const U128& z) { + return Inst(Opcode::SHA256MessageSchedule1, x, y, z); +} + UAny IREmitter::VectorGetElement(size_t esize, const U128& a, size_t index) { ASSERT_MSG(esize * index < 128, "Invalid index"); switch (esize) { diff --git a/src/dynarmic/ir/ir_emitter.h b/src/dynarmic/ir/ir_emitter.h index db82ba38..62fa6323 100644 --- a/src/dynarmic/ir/ir_emitter.h +++ b/src/dynarmic/ir/ir_emitter.h @@ -237,6 +237,8 @@ public: U8 SM4AccessSubstitutionBox(const U8& a); U128 SHA256Hash(const U128& x, const U128& y, const U128& w, bool part1); + U128 SHA256MessageSchedule0(const U128& x, const U128& y); + U128 SHA256MessageSchedule1(const U128& x, const U128& y, const U128& z); UAny VectorGetElement(size_t esize, const U128& a, size_t index); U128 VectorSetElement(size_t esize, const U128& a, size_t index, const UAny& elem); diff --git a/src/dynarmic/ir/opcodes.inc b/src/dynarmic/ir/opcodes.inc index f72f22ac..d4ca86f2 100644 --- a/src/dynarmic/ir/opcodes.inc +++ b/src/dynarmic/ir/opcodes.inc @@ -274,6 +274,8 @@ OPCODE(SM4AccessSubstitutionBox, U8, U8 // SHA instructions OPCODE(SHA256Hash, U128, U128, U128, U128, U1 ) +OPCODE(SHA256MessageSchedule0, U128, U128, U128 ) +OPCODE(SHA256MessageSchedule1, U128, U128, U128, U128 ) // Vector instructions OPCODE(VectorGetElement8, U8, U128, U8 )