From 2eb174713b8f819f7e013e48efb2d0736c065ed9 Mon Sep 17 00:00:00 2001
From: Subv <subv2112@gmail.com>
Date: Wed, 14 Jun 2017 16:59:16 -0500
Subject: [PATCH] UDS: Use the ToDS and FromDS fields to properly calculate the
 AAD used during encryption.

---
 src/core/hle/service/nwm/uds_data.cpp | 47 ++++++++++++++++++---------
 1 file changed, 32 insertions(+), 15 deletions(-)

diff --git a/src/core/hle/service/nwm/uds_data.cpp b/src/core/hle/service/nwm/uds_data.cpp
index 280c73e05f..8c6742dba1 100644
--- a/src/core/hle/service/nwm/uds_data.cpp
+++ b/src/core/hle/service/nwm/uds_data.cpp
@@ -103,7 +103,8 @@ static std::array<u8, CryptoPP::AES::BLOCKSIZE> GenerateDataCCMPKey(
  * Generates the Additional Authenticated Data (AAD) for an UDS 802.11 encrypted data frame.
  * @returns a buffer with the bytes of the AAD.
  */
-static std::vector<u8> GenerateCCMPAAD(const MacAddress& sender, const MacAddress& receiver) {
+static std::vector<u8> GenerateCCMPAAD(const MacAddress& sender, const MacAddress& receiver,
+                                       const MacAddress& bssid, u16 frame_control) {
     // Reference: IEEE 802.11-2007
 
     // 8.3.3.3.2 Construct AAD (22-30 bytes)
@@ -113,20 +114,34 @@ static std::vector<u8> GenerateCCMPAAD(const MacAddress& sender, const MacAddres
     // Control field are masked to 0.
     struct {
         u16_be FC; // MPDU Frame Control field
-        MacAddress receiver;
-        MacAddress transmitter;
-        MacAddress destination;
+        MacAddress A1;
+        MacAddress A2;
+        MacAddress A3;
         u16_be SC; // MPDU Sequence Control field
     } aad_struct{};
 
-    // Default FC value of DataFrame | Protected | ToDS
-    constexpr u16 DefaultFrameControl = 0x0841;
-
-    aad_struct.FC = DefaultFrameControl;
+    constexpr u16 AADFrameControlMask = 0x8FC7;
+    aad_struct.FC = frame_control & AADFrameControlMask;
     aad_struct.SC = 0;
-    aad_struct.transmitter = sender;
-    aad_struct.receiver = receiver;
-    aad_struct.destination = receiver;
+
+    bool to_ds = (frame_control & (1 << 0)) != 0;
+    bool from_ds = (frame_control & (1 << 1)) != 0;
+    // In the 802.11 standard, ToDS = 1 and FromDS = 1 is a valid configuration,
+    // however, the 3DS doesn't seem to transmit frames with such combination.
+    ASSERT_MSG(to_ds != from_ds, "Invalid combination");
+
+    // The meaning of the address fields depends on the ToDS and FromDS fields.
+    if (from_ds) {
+        aad_struct.A1 = receiver;
+        aad_struct.A2 = bssid;
+        aad_struct.A3 = sender;
+    }
+
+    if (to_ds) {
+        aad_struct.A1 = bssid;
+        aad_struct.A2 = sender;
+        aad_struct.A3 = receiver;
+    }
 
     std::vector<u8> aad(sizeof(aad_struct));
     std::memcpy(aad.data(), &aad_struct, sizeof(aad_struct));
@@ -141,11 +156,12 @@ static std::vector<u8> GenerateCCMPAAD(const MacAddress& sender, const MacAddres
 static std::vector<u8> DecryptDataFrame(const std::vector<u8>& encrypted_payload,
                                         const std::array<u8, CryptoPP::AES::BLOCKSIZE>& ccmp_key,
                                         const MacAddress& sender, const MacAddress& receiver,
-                                        u16 sequence_number) {
+                                        const MacAddress& bssid, u16 sequence_number,
+                                        u16 frame_control) {
 
     // Reference: IEEE 802.11-2007
 
-    std::vector<u8> aad = GenerateCCMPAAD(sender, receiver);
+    std::vector<u8> aad = GenerateCCMPAAD(sender, receiver, bssid, frame_control);
 
     std::vector<u8> packet_number{0,
                                   0,
@@ -200,10 +216,11 @@ static std::vector<u8> DecryptDataFrame(const std::vector<u8>& encrypted_payload
 static std::vector<u8> EncryptDataFrame(const std::vector<u8>& payload,
                                         const std::array<u8, CryptoPP::AES::BLOCKSIZE>& ccmp_key,
                                         const MacAddress& sender, const MacAddress& receiver,
-                                        u16 sequence_number) {
+                                        const MacAddress& bssid, u16 sequence_number,
+                                        u16 frame_control) {
     // Reference: IEEE 802.11-2007
 
-    std::vector<u8> aad = GenerateCCMPAAD(sender, receiver);
+    std::vector<u8> aad = GenerateCCMPAAD(sender, receiver, bssid, frame_control);
 
     std::vector<u8> packet_number{0,
                                   0,