// Copyright Citra Emulator Project / Azahar Emulator Project // Licensed under GPLv2 or any later version // Refer to the license.txt file included. #include "artic_base_client.h" #include "common/assert.h" #include "common/logging/log.h" #include "chrono" #include "limits.h" #include "memory" #include "sstream" #ifdef _WIN32 #include #include #else #include #include #include #include #include #include #include #include #include #include #include #endif #ifdef _WIN32 #define WSAEAGAIN WSAEWOULDBLOCK #define WSAEMULTIHOP -1 // Invalid dummy value #define ERRNO(x) WSA##x #define GET_ERRNO WSAGetLastError() #define poll(x, y, z) WSAPoll(x, y, z); #define SHUT_RD SD_RECEIVE #define SHUT_WR SD_SEND #define SHUT_RDWR SD_BOTH #else #define ERRNO(x) x #define GET_ERRNO errno #define closesocket(x) close(x) #endif namespace Network::ArticBase { using namespace std::chrono_literals; bool Client::Request::AddParameterS8(s8 parameter) { if (parameters.size() >= max_param_count) { LOG_ERROR(Network, "Too many parameters added to method: {}", method_name); return false; } auto& param = parameters.emplace_back(); param.type = ArticBaseCommon::RequestParameterType::IN_INTEGER_8; std::memcpy(param.data, ¶meter, sizeof(s8)); return true; } bool Client::Request::AddParameterS16(s16 parameter) { if (parameters.size() >= max_param_count) { LOG_ERROR(Network, "Too many parameters added to method: {}", method_name); return false; } auto& param = parameters.emplace_back(); param.type = ArticBaseCommon::RequestParameterType::IN_INTEGER_16; std::memcpy(param.data, ¶meter, sizeof(s16)); return true; } bool Client::Request::AddParameterS32(s32 parameter) { if (parameters.size() >= max_param_count) { LOG_ERROR(Network, "Too many parameters added to method: {}", method_name); return false; } auto& param = parameters.emplace_back(); param.type = ArticBaseCommon::RequestParameterType::IN_INTEGER_32; std::memcpy(param.data, ¶meter, sizeof(s32)); return true; } bool Client::Request::AddParameterS64(s64 parameter) { if (parameters.size() >= max_param_count) { LOG_ERROR(Network, "Too many parameters added to method: {}", method_name); return false; } auto& param = parameters.emplace_back(); param.type = ArticBaseCommon::RequestParameterType::IN_INTEGER_64; std::memcpy(param.data, ¶meter, sizeof(s64)); return true; } bool Client::Request::AddParameterBuffer(const void* buffer, size_t size) { if (parameters.size() >= max_param_count) { LOG_ERROR(Network, "Too many parameters added to method: {}", method_name); return false; } auto& param = parameters.emplace_back(); if (size <= sizeof(param.data)) { param.type = ArticBaseCommon::RequestParameterType::IN_SMALL_BUFFER; std::memcpy(param.data, buffer, size); param.parameterSize = static_cast(size); } else { param.type = ArticBaseCommon::RequestParameterType::IN_BIG_BUFFER; param.bigBufferID = static_cast(pending_big_buffers.size()); s32 size_32 = static_cast(size); std::memcpy(param.data, &size_32, sizeof(size_32)); pending_big_buffers.push_back(std::make_pair(buffer, size)); } return true; } Client::Request::Request(u32 request_id, const std::string& method, size_t max_params) { method_name = method; max_param_count = max_params; request_packet.requestID = request_id; std::memcpy(request_packet.method.data(), method.data(), std::min(request_packet.method.size(), method.size())); } void Client::UDPStream::Start() { thread_run = true; handle_thread = std::thread(&Client::UDPStream::Handle, this); } void Client::UDPStream::Handle() { struct sockaddr_in* servaddr = reinterpret_cast(serv_sockaddr_in.data()); socklen_t serv_sockaddr_len = static_cast(serv_sockaddr_in.size()); memcpy(servaddr, client.GetServerAddr().data(), client.GetServerAddr().size()); servaddr->sin_port = htons(port); main_socket = ::socket(AF_INET, SOCK_DGRAM, 0); if (main_socket == static_cast(-1) || !thread_run) { LOG_ERROR(Network, "Failed to create socket"); return; } if (!SetNonBlock(main_socket, true) || !thread_run) { closesocket(main_socket); LOG_ERROR(Network, "Cannot set non-blocking socket mode"); return; } // Limit receive buffer so that packets don't get qeued and are dropped instead. int buffer_size_int = static_cast(buffer_size); if (::setsockopt(main_socket, SOL_SOCKET, SO_RCVBUF, reinterpret_cast(&buffer_size_int), sizeof(buffer_size_int)) || !thread_run) { closesocket(main_socket); LOG_ERROR(Network, "Cannot change receive buffer size"); return; } // Send data to server so that it knows client address. char zero = '\0'; int send_res = ::sendto(main_socket, &zero, sizeof(char), 0, reinterpret_cast(serv_sockaddr_in.data()), serv_sockaddr_len); if (send_res < 0 || !thread_run) { closesocket(main_socket); LOG_ERROR(Network, "Cannot send data to socket"); return; } ready = true; std::vector buffer(buffer_size); while (thread_run) { std::chrono::steady_clock::time_point before = std::chrono::steady_clock::now(); int packet_size = ::recvfrom( main_socket, reinterpret_cast(buffer.data()), static_cast(buffer.size()), 0, reinterpret_cast(serv_sockaddr_in.data()), &serv_sockaddr_len); if (packet_size > 0) { if (client.report_traffic_callback) { client.report_traffic_callback(packet_size); } buffer.resize(packet_size); { std::scoped_lock l(current_buffer_mutex); current_buffer = buffer; } } auto elapsed = std::chrono::steady_clock::now() - before; std::unique_lock lk(thread_cv_mutex); thread_cv.wait_for(lk, elapsed < read_interval ? (read_interval - elapsed) : std::chrono::microseconds(50)); } ready = false; closesocket(main_socket); } Client::~Client() { StopImpl(false); for (auto it = handlers.begin(); it != handlers.end(); it++) { Handler* h = *it; h->thread->join(); delete h; } if (ping_thread.joinable()) { ping_thread.join(); } SocketManager::DisableSockets(); } bool Client::Connect() { if (connected) return true; auto str_to_int = [](const std::string& str) -> int { char* pEnd = NULL; unsigned long ul = ::strtoul(str.c_str(), &pEnd, 10); if (*pEnd) return -1; return static_cast(ul); }; struct addrinfo hints, *addrinfo; memset(&hints, 0, sizeof(hints)); hints.ai_socktype = SOCK_STREAM; hints.ai_family = AF_INET; LOG_INFO(Network, "Starting Artic Client"); if (getaddrinfo(address.data(), NULL, &hints, &addrinfo) != 0) { LOG_ERROR(Network, "Failed to get server address"); SignalCommunicationError(); return false; } main_socket = ::socket(AF_INET, SOCK_STREAM, 0); if (main_socket == static_cast(-1)) { LOG_ERROR(Network, "Failed to create socket"); SignalCommunicationError(); return false; } if (!SetNonBlock(main_socket, true)) { shutdown(main_socket, SHUT_RDWR); closesocket(main_socket); LOG_ERROR(Network, "Cannot set non-blocking socket mode"); SignalCommunicationError(); return false; } struct sockaddr_in servaddr = {0}; servaddr.sin_family = AF_INET; servaddr.sin_addr.s_addr = ((struct sockaddr_in*)(addrinfo->ai_addr))->sin_addr.s_addr; servaddr.sin_port = htons(port); freeaddrinfo(addrinfo); memcpy(last_sockaddr_in.data(), &servaddr, last_sockaddr_in.size()); if (!ConnectWithTimeout(main_socket, &servaddr, sizeof(servaddr), 10)) { closesocket(main_socket); LOG_ERROR(Network, "Failed to connect"); SignalCommunicationError(); return false; } auto version = SendSimpleRequest("VERSION"); if (version.has_value()) { int version_value = str_to_int(*version); if (version_value != SERVER_VERSION) { shutdown(main_socket, SHUT_RDWR); closesocket(main_socket); LOG_ERROR(Network, "Incompatible server version: {}", version_value); SignalCommunicationError("\nIncompatible Artic Server version.\nCheck for updates " "to the Artic Server or Azahar."); return false; } } else { shutdown(main_socket, SHUT_RDWR); closesocket(main_socket); LOG_ERROR(Network, "Couldn't fetch server version."); SignalCommunicationError(); return false; } auto max_work_size = SendSimpleRequest("MAXSIZE"); int max_work_size_value = -1; if (max_work_size.has_value()) { max_work_size_value = str_to_int(*max_work_size); } if (max_work_size_value < 0) { shutdown(main_socket, SHUT_RDWR); closesocket(main_socket); LOG_ERROR(Network, "Couldn't fetch server work ram size"); SignalCommunicationError(); return false; } max_server_work_ram = max_work_size_value; auto max_params = SendSimpleRequest("MAXPARAM"); int max_param_value = -1; if (max_params.has_value()) { max_param_value = str_to_int(*max_params); } if (max_param_value < 0) { shutdown(main_socket, SHUT_RDWR); closesocket(main_socket); LOG_ERROR(Network, "Couldn't fetch server max params"); SignalCommunicationError(); return false; } max_parameter_count = max_param_value; auto worker_ports = SendSimpleRequest("PORTS"); if (!worker_ports.has_value()) { shutdown(main_socket, SHUT_RDWR); closesocket(main_socket); LOG_ERROR(Network, "Couldn't fetch server worker ports"); SignalCommunicationError(); return false; } std::vector ports; std::string str_port; std::stringstream ss_port(worker_ports.value()); while (std::getline(ss_port, str_port, ',')) { int port_curr = str_to_int(str_port); if (port_curr < 0 || port_curr > static_cast(USHRT_MAX)) { shutdown(main_socket, SHUT_RDWR); closesocket(main_socket); LOG_ERROR(Network, "Couldn't parse server worker ports"); SignalCommunicationError(); return false; } ports.push_back(static_cast(port_curr)); } if (ports.empty()) { shutdown(main_socket, SHUT_RDWR); closesocket(main_socket); LOG_ERROR(Network, "Couldn't parse server worker ports"); SignalCommunicationError(); return false; } for (int i = 0; i < 101; i++) { auto ready_server = SendSimpleRequest("READY"); if (!ready_server.has_value() || i == 100) { shutdown(main_socket, SHUT_RDWR); closesocket(main_socket); LOG_ERROR(Network, "Couldn't fetch server readiness"); SignalCommunicationError(); return false; } if (*ready_server == "1") break; std::this_thread::sleep_for(100ms); } ping_thread = std::thread(&Client::PingFunction, this); int i = 0; running_handlers = ports.size(); for (auto it = ports.begin(); it != ports.end(); it++) { handlers.push_back(new Handler(*this, static_cast(servaddr.sin_addr.s_addr), *it, i)); i++; } connected = true; return true; } std::shared_ptr Client::NewUDPStream( const std::string stream_id, size_t buffer_size, const std::chrono::milliseconds& read_interval) { auto req = NewRequest("#" + stream_id); auto resp = Send(req); if (!resp.has_value()) { return nullptr; } auto port_udp = resp->GetResponseS32(0); if (!port_udp.has_value()) { return nullptr; } udp_streams.push_back(std::make_shared(*this, static_cast(*port_udp), buffer_size, read_interval)); return udp_streams.back(); } void Client::StopImpl(bool from_error) { bool expected = false; if (!stopped.compare_exchange_strong(expected, true)) return; if (!from_error) { SendSimpleRequest("STOP"); } for (auto it = udp_streams.begin(); it != udp_streams.end(); it++) { it->get()->Stop(); } if (ping_thread.joinable()) { std::scoped_lock l2(ping_cv_mutex); ping_run = false; ping_cv.notify_one(); } // Stop handlers for (auto it = handlers.begin(); it != handlers.end(); it++) { Handler* handler = *it; handler->should_run = false; // Shouldn't matter if the socket is shut down twice shutdown(handler->handler_socket, SHUT_RDWR); closesocket(handler->handler_socket); } // Close main socket shutdown(main_socket, SHUT_RDWR); closesocket(main_socket); } std::optional> Client::Response::GetResponseBuffer(u32 buffer_id) const { if (!resp_data_buffer) return std::nullopt; char* resp_data_buffer_end = resp_data_buffer + resp_data_size; char* resp_data_buffer_start = resp_data_buffer; while (resp_data_buffer_start + sizeof(ArticBaseCommon::Buffer) < resp_data_buffer_end) { ArticBaseCommon::Buffer* curr_buffer = reinterpret_cast(resp_data_buffer_start); resp_data_buffer_start += sizeof(ArticBaseCommon::Buffer); if (curr_buffer->bufferID == buffer_id) { if (curr_buffer->data + curr_buffer->bufferSize <= resp_data_buffer_end) { return std::make_pair(curr_buffer->data, curr_buffer->bufferSize); } else { return std::nullopt; } } resp_data_buffer_start += curr_buffer->bufferSize; } return std::nullopt; } std::optional Client::Send(Request& request) { if (stopped) return std::nullopt; request.request_packet.parameterCount = static_cast(request.parameters.size()); PendingResponse resp(request); { std::scoped_lock l(recv_map_mutex); pending_responses[request.request_packet.requestID] = &resp; } auto respPacket = SendRequestPacket(request.request_packet, false, request.parameters); if (stopped || !respPacket.has_value()) { std::scoped_lock l(recv_map_mutex); pending_responses.erase(request.request_packet.requestID); return std::nullopt; } std::unique_lock cv_lk(resp.cv_mutex); resp.cv.wait(cv_lk, [&resp]() { return resp.is_done; }); return std::optional(std::move(resp.response)); } void Client::SignalCommunicationError(const std::string& msg) { StopImpl(true); LOG_CRITICAL(Network, "Communication error"); if (communication_error_callback) communication_error_callback(msg); } void Client::PingFunction() { // Max silence time => 7 secs interval + 3 secs wait + 10 seconds timeout = 25 seconds while (ping_run) { std::chrono::time_point last = last_sent_request; if (std::chrono::steady_clock::now() - last > std::chrono::seconds(7)) { if (ping_enabled) { auto ping_reply = SendSimpleRequest("PING"); if (!ping_reply.has_value()) { SignalCommunicationError(); break; } } else { last_sent_request = std::chrono::steady_clock::now(); } } std::unique_lock lk(ping_cv_mutex); ping_cv.wait_for(lk, std::chrono::seconds(3)); } } bool Client::ConnectWithTimeout(SocketHolder sockFD, void* server_addr, size_t server_addr_len, int timeout_seconds) { int res = ::connect(sockFD, (struct sockaddr*)server_addr, static_cast(server_addr_len)); if (res == -1 && ((GET_ERRNO == ERRNO(EINPROGRESS) || GET_ERRNO == ERRNO(EWOULDBLOCK)))) { struct timeval tv; fd_set fdset; FD_ZERO(&fdset); FD_SET(sockFD, &fdset); tv.tv_sec = timeout_seconds; tv.tv_usec = 0; int select_res = ::select(static_cast(sockFD + 1), NULL, &fdset, NULL, &tv); #ifdef _WIN32 if (select_res == 0) { return false; } #else bool select_good = false; if (select_res == 1) { int so_error; socklen_t len = sizeof so_error; getsockopt(sockFD, SOL_SOCKET, SO_ERROR, &so_error, &len); if (so_error == 0) { select_good = true; } } if (!select_good) { return false; } #endif // _WIN32 } else if (res == -1) { return false; } return true; } bool Client::SetNonBlock(SocketHolder sockFD, bool nonBlocking) { bool blocking = !nonBlocking; #ifdef _WIN32 unsigned long nonblocking = (blocking) ? 0 : 1; int ret = ioctlsocket(sockFD, FIONBIO, &nonblocking); if (ret == -1) { return false; } #else int flags = ::fcntl(sockFD, F_GETFL, 0); if (flags == -1) { return false; } flags &= ~O_NONBLOCK; if (!blocking) { // O_NONBLOCK flags |= O_NONBLOCK; } const int ret = ::fcntl(sockFD, F_SETFL, flags); if (ret == -1) { return false; } #endif return true; } bool Client::Read(SocketHolder sockFD, void* buffer, size_t size, const std::chrono::nanoseconds& timeout) { size_t read_bytes = 0; auto before = std::chrono::steady_clock::now(); while (read_bytes != size) { int new_read = ::recv(sockFD, (char*)((uintptr_t)buffer + read_bytes), (int)(size - read_bytes), 0); if (new_read < 0) { if (GET_ERRNO == ERRNO(EWOULDBLOCK) && (timeout == std::chrono::nanoseconds(0) || std::chrono::steady_clock::now() - before < timeout)) { continue; } read_bytes = 0; break; } if (report_traffic_callback && new_read) { report_traffic_callback(new_read); } read_bytes += new_read; } return read_bytes == size; } bool Client::Write(SocketHolder sockFD, const void* buffer, size_t size, const std::chrono::nanoseconds& timeout) { size_t write_bytes = 0; auto before = std::chrono::steady_clock::now(); while (write_bytes != size) { int new_written = ::send(sockFD, (const char*)((uintptr_t)buffer + write_bytes), (int)(size - write_bytes), 0); if (new_written < 0) { if (GET_ERRNO == ERRNO(EWOULDBLOCK) && (timeout == std::chrono::nanoseconds(0) || std::chrono::steady_clock::now() - before < timeout)) { continue; } write_bytes = 0; break; } if (report_traffic_callback && new_written) { report_traffic_callback(new_written); } write_bytes += new_written; } return write_bytes == size; } std::optional Client::SendRequestPacket( const ArticBaseCommon::RequestPacket& req, bool expect_response, const std::vector& params, const std::chrono::nanoseconds& read_timeout) { std::scoped_lock l(send_mutex); if (main_socket == static_cast(-1)) { return std::nullopt; } if (!Write(main_socket, &req, sizeof(req))) { LOG_WARNING(Network, "Failed to write to socket"); SignalCommunicationError(); return std::nullopt; } if (!params.empty()) { if (!Write(main_socket, params.data(), params.size() * sizeof(ArticBaseCommon::RequestParameter))) { LOG_WARNING(Network, "Failed to write to socket"); SignalCommunicationError(); return std::nullopt; } } ArticBaseCommon::DataPacket resp; if (expect_response) { if (!Read(main_socket, &resp, sizeof(resp), read_timeout)) { LOG_WARNING(Network, "Failed to read from socket"); SignalCommunicationError(); return std::nullopt; } if (resp.requestID != req.requestID) { return std::nullopt; } } last_sent_request = std::chrono::steady_clock::now(); return resp; } std::optional Client::SendSimpleRequest(const std::string& method) { ArticBaseCommon::RequestPacket req{}; req.requestID = GetNextRequestID(); const std::string final_method = "$" + method; if (final_method.size() > sizeof(req.method)) { return std::nullopt; } std::memcpy(req.method.data(), final_method.data(), final_method.size()); auto resp = SendRequestPacket(req, true, {}, std::chrono::seconds(10)); if (!resp.has_value() || resp->requestID != req.requestID) { return std::nullopt; } char respBody[sizeof(ArticBaseCommon::DataPacket::dataRaw) + 1] = {0}; std::memcpy(respBody, resp->dataRaw, sizeof(ArticBaseCommon::DataPacket::dataRaw)); return respBody; } Client::Handler::Handler(Client& _client, u32 _addr, u16 _port, int _id) : id(_id), client(_client), addr(_addr), port(_port) { thread = new std::thread( [](Handler* handler) { handler->RunLoop(); handler->should_run = false; if (--handler->client.running_handlers == 0) { handler->client.OnAllHandlersFinished(); } }, this); } void Client::Handler::RunLoop() { handler_socket = ::socket(AF_INET, SOCK_STREAM, 0); if (handler_socket == static_cast(-1)) { LOG_ERROR(Network, "Failed to create socket"); return; } if (!SetNonBlock(handler_socket, true)) { closesocket(handler_socket); client.SignalCommunicationError(); LOG_ERROR(Network, "Cannot set non-blocking socket mode"); return; } struct sockaddr_in servaddr = {0}; servaddr.sin_family = AF_INET; servaddr.sin_addr.s_addr = static_cast(addr); servaddr.sin_port = htons(port); if (!ConnectWithTimeout(handler_socket, &servaddr, sizeof(servaddr), 10)) { closesocket(handler_socket); LOG_ERROR(Network, "Failed to connect"); client.SignalCommunicationError(); return; } const auto signal_error = [&] { if (should_run) { client.SignalCommunicationError(); } }; ArticBaseCommon::DataPacket dataPacket; u32 retry_count = 0; while (should_run) { if (!client.Read(handler_socket, &dataPacket, sizeof(dataPacket))) { if (should_run) { LOG_WARNING(Network, "Failed to read from socket"); std::this_thread::sleep_for(100ms); if (++retry_count == 300) { signal_error(); break; } continue; } else { break; } } retry_count = 0; PendingResponse* pending_response; { std::scoped_lock l(client.recv_map_mutex); auto it = client.pending_responses.find(dataPacket.requestID); if (it == client.pending_responses.end()) { continue; } pending_response = it->second; } switch (dataPacket.resp.articResult) { case ArticBaseCommon::ResponseMethod::ArticResult::SUCCESS: { pending_response->response.articResult = dataPacket.resp.articResult; pending_response->response.methodResult = dataPacket.resp.methodResult; if (dataPacket.resp.bufferSize) { pending_response->response.resp_data_buffer = reinterpret_cast(operator new(dataPacket.resp.bufferSize)); ASSERT_MSG(pending_response->response.resp_data_buffer != nullptr, "ArticBase Handler: Cannot allocate buffer"); pending_response->response.resp_data_size = static_cast(dataPacket.resp.bufferSize); if (!client.Read(handler_socket, pending_response->response.resp_data_buffer, dataPacket.resp.bufferSize)) { signal_error(); } } } break; case ArticBaseCommon::ResponseMethod::ArticResult::METHOD_NOT_FOUND: { LOG_ERROR(Network, "Method {} not found by server", pending_response->request.method_name); pending_response->response.articResult = dataPacket.resp.articResult; } break; case ArticBaseCommon::ResponseMethod::ArticResult::PROVIDE_INPUT: { size_t bufferID = static_cast(dataPacket.resp.provideInputBufferID); if (bufferID >= pending_response->request.pending_big_buffers.size() || pending_response->request.pending_big_buffers[bufferID].second != static_cast(dataPacket.resp.bufferSize)) { LOG_ERROR(Network, "Method {} incorrect big buffer state {}", pending_response->request.method_name, bufferID); dataPacket.resp.articResult = ArticBaseCommon::ResponseMethod::ArticResult::METHOD_ERROR; if (client.Write(handler_socket, &dataPacket, sizeof(dataPacket))) { continue; } else { signal_error(); } } else { auto& buffer = pending_response->request.pending_big_buffers[bufferID]; if (client.Write(handler_socket, &dataPacket, sizeof(dataPacket))) { if (client.Write(handler_socket, buffer.first, buffer.second)) { continue; } else { signal_error(); } } else { signal_error(); } } } break; case ArticBaseCommon::ResponseMethod::ArticResult::METHOD_ERROR: default: { LOG_ERROR(Network, "Method {} error {}", pending_response->request.method_name, dataPacket.resp.methodResult); pending_response->response.articResult = dataPacket.resp.articResult; pending_response->response.methodState = static_cast(dataPacket.resp.methodResult); } break; } { std::scoped_lock l(client.recv_map_mutex); client.pending_responses.erase(dataPacket.requestID); } { std::scoped_lock lk(pending_response->cv_mutex); pending_response->is_done = true; pending_response->cv.notify_one(); } } should_run = false; shutdown(handler_socket, SHUT_RDWR); closesocket(handler_socket); } void Client::OnAllHandlersFinished() { // If no handlers are running, signal all pending requests so that // they don't become stuck. std::scoped_lock l(recv_map_mutex); for (auto& [id, response] : pending_responses) { std::scoped_lock l2(response->cv_mutex); response->is_done = true; response->cv.notify_one(); } pending_responses.clear(); } } // namespace Network::ArticBase