#include "ArticBaseServer.hpp" #include "ArticBaseFunctions.hpp" #include #include #include #include #include #include #include #include #include #include #include #include extern int transferedBytes; ArticBaseServer::ArticBaseServer(int sock_fd) { socketFd = sock_fd; LightEvent_Init(&newPendingRequest, ResetType::RESET_ONESHOT); for (int i = 0; i < requestHandlers.size(); i++) { requestHandlers[i] = new RequestHandler(this, i); } } ArticBaseServer::~ArticBaseServer() { for (int i = 0; i < requestHandlers.size(); i++) { delete requestHandlers[i]; } if (socketFd >= 0) { int fd = socketFd; socketFd = -1; shutdown(fd, SHUT_RDWR); close(fd); } } void ArticBaseServer::Serve() { ArticBaseCommon::RequestPacket req; std::array params; int retryCount = 0; while (run && !stopQueried) { if (!Read(socketFd, &req, sizeof(req))) { if (stopQueried) { break; } if (run) { logger.Error("Server: Error reading from socket"); svcSleepThread(1000000000); if (++retryCount == 3) { break; } } continue; } std::array methodArray = {0}; memcpy(methodArray.data(), req.method.data(), req.method.size()); // Process special method now, delegate otherwise if (methodArray[0] == '$') { logger.Debug("Server: Processing %s (rID %d)", methodArray.data(), req.requestID); ArticBaseCommon::DataPacket resp{}; resp.requestID = req.requestID; std::string_view method(methodArray.data()); if (method == "$PING") { // Do nothing } else if (method == "$VERSION") { strcpy(resp.dataRaw, VERSION); } else if (method == "$PORTS") { snprintf(resp.dataRaw, sizeof(resp.dataRaw), "%d,%d,%d,%d", SERVER_PORT + 1, SERVER_PORT + 2, SERVER_PORT + 3, SERVER_PORT + 4); } else if (method == "$MAXSIZE") { snprintf(resp.dataRaw, sizeof(resp.dataRaw), "%d", ArticBaseServer::MAX_WORK_BUF_SIZE); } else if (method == "$MAXPARAM") { snprintf(resp.dataRaw, sizeof(resp.dataRaw), "%d", ArticBaseServer::MAX_PARAM_AMOUNT); } else if (method == "$READY") { bool ready = true; for (int i = 0; i < requestHandlers.size(); i++) { if (!requestHandlers[i]->ready) { ready = false; } } snprintf(resp.dataRaw, sizeof(resp.dataRaw), "%d", ready ? 1 : 0); } else if (method == "$STOP") { stopQueried = true; } else { logger.Error("Server: Method not found: %s", methodArray.data()); } if (!Write(socketFd, &resp, sizeof(resp))) { if (run) logger.Error("Server: Error writing to socket"); continue; } } else { u32 readParams = 0; if (req.parameterCount) { if (req.parameterCount <= MAX_PARAM_AMOUNT) { if (!Read(socketFd, params.data(), req.parameterCount * sizeof(ArticBaseCommon::RequestParameter))) { if (run) logger.Error("Server: Error reading from socket"); continue; } readParams = req.parameterCount; } else { logger.Error("Server: Too many parameters in request"); readParams = 0; } } { CTRPluginFramework::Lock l(pendingRequestsMutex); pendingRequests.push(Request{.reqPacket = req, .reqParameters = std::vector(params.begin(), params.begin() + readParams)}); } LightEvent_Signal(&newPendingRequest); } } Stop(); } void ArticBaseServer::QueryStop() { stopQueried = true; if (socketFd >= 0) { int fd = socketFd; socketFd = -1; shutdown(fd, SHUT_RDWR); close(fd); } } bool ArticBaseServer::SetNonBlock(int sockFD, bool nonBlocking) { int flags = fcntl(sockFD, F_GETFL, 0); if (flags < 0) { return false; } if (nonBlocking) { flags |= O_NONBLOCK; } else { flags &= ~O_NONBLOCK; } int res = fcntl(sockFD, F_SETFL, flags); if (res < 0) { return false; } return true; } void ArticBaseServer::Stop() { if (!run) return; run = false; if (socketFd >= 0) { int fd = socketFd; socketFd = -1; shutdown(fd, SHUT_RDWR); close(fd); } for (int i = 0; i < requestHandlers.size(); i++) { RequestHandler& handler = *requestHandlers[i]; handler.run = false; if (handler.accept_fd >= 0) { int fd = handler.accept_fd; handler.accept_fd = -1; shutdown(fd, SHUT_RDWR); close(fd); } if (handler.listen_fd >= 0) { int fd = handler.listen_fd; handler.listen_fd = -1; shutdown(fd, SHUT_RDWR); close(fd); } } while(true) { bool allStopped = true; for (int i = 0; i < requestHandlers.size(); i++) { allStopped = allStopped && threadGetExitCode(requestHandlers[i]->thread) == 1; } if (allStopped) break; svcSleepThread(1000000); LightEvent_Signal(&newPendingRequest); } } bool ArticBaseServer::Read(int& sockFD, void* buffer, size_t size) { size_t read_bytes = 0; while (read_bytes != size) { int new_read = recv(sockFD, (void*)((uintptr_t)buffer + read_bytes), size - read_bytes, 0); if (new_read < 0) { if (errno == EWOULDBLOCK) { svcSleepThread(1000000); continue; } read_bytes = 0; break; } transferedBytes += new_read; read_bytes += new_read; } return read_bytes == size; } bool ArticBaseServer::Write(int& sockFD, void* buffer, size_t size) { size_t write_bytes = 0; while (write_bytes != size) { int new_written = send(sockFD, (void*)((uintptr_t)buffer + write_bytes), size - write_bytes, 0); if (new_written < 0) { if (errno == EWOULDBLOCK) { svcSleepThread(1000000); continue; } write_bytes = 0; break; } transferedBytes += new_written; write_bytes += new_written; } return write_bytes == size; } size_t ArticBaseServer::RecvFrom(int& sockFD, void* buffer, size_t size, void* addr, void* addr_size) { while (true) { int new_read = recvfrom(sockFD, buffer, size, 0, (sockaddr*)addr, (socklen_t*)addr_size); if (new_read < 0) { if (errno == EWOULDBLOCK) { svcSleepThread(1000000); continue; } return 0; } transferedBytes += new_read; return new_read; } } size_t ArticBaseServer::SendTo(int& sockFD, void* buffer, size_t size, void* addr, void* addr_size) { socklen_t addr_len = *(socklen_t*)addr_size; while (true) { int new_written = sendto(sockFD, buffer, size, 0, (sockaddr*)addr, addr_len); if (new_written < 0) { if (errno == EWOULDBLOCK) { svcSleepThread(1000000); continue; } return 0; } transferedBytes += new_written; return new_written; } } ArticBaseServer::RequestHandler::RequestHandler(ArticBaseServer* serv, int id) { server = serv; this->id = id; workBufferSize = ArticBaseServer::MAX_WORK_BUF_SIZE; workBuffer = linearAlloc(workBufferSize); if (workBuffer == nullptr) { logger.Error("Worker %d: Failed to allocate work buffer", id); } s32 prio = 0; svcGetThreadPriority(&prio, CUR_THREAD_HANDLE); thread = threadCreate(RequestHandler::HandleThread, this, 0x1000, prio - 1, -2, false); } ArticBaseServer::RequestHandler::~RequestHandler() { threadJoin(thread, U64_MAX); threadFree(thread); linearFree(workBuffer); } void ArticBaseServer::RequestHandler::HandleThread(void* arg) { RequestHandler* own = (RequestHandler*)arg; own->Serve(); if (own->accept_fd >= 0) { int fd = own->accept_fd; shutdown(fd, SHUT_RDWR); close(fd); } logger.Debug("Worker %d: Exited", own->id); threadExit(1); } void ArticBaseServer::RequestHandler::Serve() { if (!workBuffer) { return; } struct sockaddr_in servaddr = {0}; int res; listen_fd = socket(AF_INET, SOCK_STREAM, 0); if (listen_fd < 0) { logger.Error("Worker %d: Cannot create socket", id); return; } if (!ArticBaseServer::SetNonBlock(listen_fd, true)) { logger.Error("Worker %d: Cannot set non-block", id); close(listen_fd); listen_fd = -1; return; } servaddr.sin_family = AF_INET; servaddr.sin_addr.s_addr = htonl(INADDR_ANY); servaddr.sin_port = htons(SERVER_PORT + id + 1); res = bind(listen_fd, (struct sockaddr *) &servaddr, sizeof(servaddr)); if (res < 0) { logger.Error("Worker %d: Failed to bind() to port %d", id, SERVER_PORT + id + 1); close(listen_fd); listen_fd = -1; return; } res = listen(listen_fd, 1); if (res < 0) { if (run) { logger.Error("Worker %d: Failed to listen()", id); } close(listen_fd); listen_fd = -1; return; } struct in_addr host_id; host_id.s_addr = gethostid(); logger.Debug("Worker %d: Listening on: %s:%d", id, inet_ntoa(host_id), SERVER_PORT + id + 1); struct sockaddr_in peeraddr = {0}; socklen_t peeraddr_len = sizeof(peeraddr); ready = true; while (true) { accept_fd = accept(listen_fd, (struct sockaddr *) &peeraddr, &peeraddr_len); if (accept_fd < 0 || peeraddr_len == 0) { if (errno == EWOULDBLOCK && run) { svcSleepThread(10000000); continue; } if (run) { logger.Error("Worker %d: Failed to accept()", id); } close(listen_fd); listen_fd = -1; return; } break; } close(listen_fd); listen_fd = -1; logger.Debug("Worker %d: Accepted %s:%d", id, inet_ntoa(peeraddr.sin_addr), ntohs(peeraddr.sin_port)); if (!ArticBaseServer::SetNonBlock(accept_fd, true)) { logger.Error("Worker %d: Cannot set non-block", id); shutdown(accept_fd, SHUT_RDWR); close(accept_fd); accept_fd = -1; return; } while (true) { LightEvent_Wait(&server->newPendingRequest); if (!run) { break; } Request req; while (run) { { CTRPluginFramework::Lock l(server->pendingRequestsMutex); if (!server->pendingRequests.size()) break; req = server->pendingRequests.front(); server->pendingRequests.pop(); } ArticBaseCommon::DataPacket respPacket{}; respPacket.requestID = req.reqPacket.requestID; std::array methodArray = {0}; memcpy(methodArray.data(), req.reqPacket.method.data(), req.reqPacket.method.size()); auto it = ArticBaseFunctions::functionHandlers.find(std::string(methodArray.data())); if (it == ArticBaseFunctions::functionHandlers.end()) { respPacket.resp.articResult = ArticBaseCommon::ResponseMethod::ArticResult::METHOD_NOT_FOUND; logger.Error("Worker %d: Method not found: %s", id, methodArray.data()); if (!Write(accept_fd, &respPacket, sizeof(respPacket))) { if (run) logger.Error("Worker %d: Error writing to socket", id); return; } continue; } logger.Debug("Worker %d: Processing %s (rID %d)", id, methodArray.data(), req.reqPacket.requestID); MethodInterface mi(req, workBuffer, workBufferSize, accept_fd); it->second(mi); ArticBaseCommon::MethodState mState = mi.GetMethodState(); if (mState != ArticBaseCommon::MethodState::FINISHED) { respPacket.resp.articResult = ArticBaseCommon::ResponseMethod::ArticResult::METHOD_ERROR; logger.Error("Worker %d: %s method error: %d", id, methodArray.data(), mState); respPacket.resp.methodResult = (int)mState; if (!Write(accept_fd, &respPacket, sizeof(respPacket))) { if (run) logger.Error("Worker %d: Error writing to socket", id); return; } } else { respPacket.resp.articResult = ArticBaseCommon::ResponseMethod::ArticResult::SUCCESS; respPacket.resp.methodResult = mi.GetMethodReturnValue(); auto outBuffer = mi.GetOutputBuffer(); respPacket.resp.bufferSize = outBuffer.second; if (!Write(accept_fd, &respPacket, sizeof(respPacket))) { if (run) logger.Error("Worker %d: Error writing to socket", id); return; } if (outBuffer.second) if (!Write(accept_fd, outBuffer.first, outBuffer.second)) { if (run) logger.Error("Worker %d: Error writing to socket", id); return; } } } } } bool ArticBaseServer::MethodInterface::GetParameterS8(s8& out) { if (state != MethodState::PARSING_INPUT) return false; if (currParameter >= req.reqPacket.parameterCount) { state = MethodState::PARAMETER_COUNT_MISMATCH; return false; } ArticBaseCommon::RequestParameter& currParam = req.reqParameters[currParameter++]; if (currParam.type == ArticBaseCommon::RequestParameterType::IN_INTEGER_8) { out = *(s8*)currParam.data; return true; } state = MethodState::PARAMETER_TYPE_MISMATCH; return false; } bool ArticBaseServer::MethodInterface::GetParameterS16(s16& out) { if (state != MethodState::PARSING_INPUT) return false; if (currParameter >= req.reqPacket.parameterCount) { state = MethodState::PARAMETER_COUNT_MISMATCH; return false; } ArticBaseCommon::RequestParameter& currParam = req.reqParameters[currParameter++]; if (currParam.type == ArticBaseCommon::RequestParameterType::IN_INTEGER_16) { out = *(s16*)currParam.data; return true; } state = MethodState::PARAMETER_TYPE_MISMATCH; return false; } bool ArticBaseServer::MethodInterface::GetParameterS32(s32& out) { if (state != MethodState::PARSING_INPUT) return false; if (currParameter >= req.reqPacket.parameterCount) { state = MethodState::PARAMETER_COUNT_MISMATCH; return false; } ArticBaseCommon::RequestParameter& currParam = req.reqParameters[currParameter++]; if (currParam.type == ArticBaseCommon::RequestParameterType::IN_INTEGER_32) { out = *(s32*)currParam.data; return true; } state = MethodState::PARAMETER_TYPE_MISMATCH; return false; } bool ArticBaseServer::MethodInterface::GetParameterS64(s64& out) { if (state != MethodState::PARSING_INPUT) return false; if (currParameter >= req.reqPacket.parameterCount) { state = MethodState::PARAMETER_COUNT_MISMATCH; return false; } ArticBaseCommon::RequestParameter& currParam = req.reqParameters[currParameter++]; if (currParam.type == ArticBaseCommon::RequestParameterType::IN_INTEGER_64) { out = *(s64*)currParam.data; return true; } state = MethodState::PARAMETER_TYPE_MISMATCH; return false; } bool ArticBaseServer::MethodInterface::GetParameterBuffer(void*& outBuff, size_t& outSize) { if (state != MethodState::PARSING_INPUT) return false; if (currParameter >= req.reqPacket.parameterCount) { state = MethodState::PARAMETER_COUNT_MISMATCH; return false; } ArticBaseCommon::RequestParameter& currParam = req.reqParameters[currParameter++]; if (currParam.type == ArticBaseCommon::RequestParameterType::IN_SMALL_BUFFER) { outBuff = currParam.data; outSize = currParam.parameterSize; return true; } else if (currParam.type == ArticBaseCommon::RequestParameterType::IN_BIG_BUFFER) { // Perform request for buffer size_t bufferSize = *(size_t*)currParam.data; ArticBaseCommon::Buffer* buf = workBuffer.Reserve(currParam.bigBufferID, bufferSize); if (buf == nullptr) { state = MethodState::OUT_OF_MEMORY; return false; } ArticBaseCommon::DataPacket dataPacket{}; dataPacket.requestID = req.reqPacket.requestID; dataPacket.resp.articResult = ArticBaseCommon::ResponseMethod::ArticResult::PROVIDE_INPUT; dataPacket.resp.provideInputBufferID = currParam.bigBufferID; dataPacket.resp.bufferSize = (int)bufferSize; if (!Write(socketFD, &dataPacket, sizeof(dataPacket))) { state = MethodState::BIG_BUFFER_WRITE_FAIL; return false; } if (!Read(socketFD, &dataPacket, sizeof(dataPacket))) { state = MethodState::BIG_BUFFER_READ_FAIL; return false; } if (dataPacket.requestID != req.reqPacket.requestID || dataPacket.resp.articResult != ArticBaseCommon::ResponseMethod::ArticResult::PROVIDE_INPUT || dataPacket.resp.provideInputBufferID != currParam.bigBufferID || dataPacket.resp.bufferSize != (int)bufferSize) { state = MethodState::BIG_BUFFER_READ_FAIL; return false; } if (!Read(socketFD, buf->data, buf->bufferSize)) { state = MethodState::BIG_BUFFER_READ_FAIL; return false; } outBuff = buf->data; outSize = buf->bufferSize; return true; } state = MethodState::PARAMETER_TYPE_MISMATCH; return false; } ArticBaseCommon::Buffer* ArticBaseServer::MethodInterface::ReserveResultBuffer(u32 bufferID, size_t resultBuffSize) { if (state != MethodState::GENERATING_OUTPUT) { if (state == MethodState::PARSING_INPUT) { state = MethodState::UNEXPECTED_PARSING_INPUT; } return nullptr; } ArticBaseCommon::Buffer* buf = workBuffer.Reserve(bufferID, resultBuffSize); if (buf == nullptr) { state = MethodState::OUT_OF_MEMORY_OUTPUT; return nullptr; } return buf; } ArticBaseCommon::Buffer* ArticBaseServer::MethodInterface::ResizeLastResultBuffer(ArticBaseCommon::Buffer* buffer, size_t newSize) { if (state != MethodState::GENERATING_OUTPUT) { if (state == MethodState::PARSING_INPUT) { state = MethodState::UNEXPECTED_PARSING_INPUT; } return nullptr; } WorkBufferHandler::ResizeState resizeSate = workBuffer.ResizeLast(buffer, newSize); if (resizeSate != WorkBufferHandler::ResizeState::GOOD) { if (resizeSate == WorkBufferHandler::ResizeState::OUT_OF_MEMORY) state = MethodState::OUT_OF_MEMORY_OUTPUT; else state = MethodState::INTERNAL_METHOD_ERROR; return nullptr; } return buffer; } void ArticBaseServer::MethodInterface::FinishGood(int returnValue) { if (state == MethodState::GENERATING_OUTPUT) state = MethodState::FINISHED; if (state == MethodState::PARSING_INPUT) state = MethodState::UNEXPECTED_PARSING_INPUT; this->returnValue = returnValue; } void ArticBaseServer::MethodInterface::FinishInternalError() { if (state == MethodState::GENERATING_OUTPUT || state == MethodState::PARSING_INPUT) state = MethodState::INTERNAL_METHOD_ERROR; return; } ArticBaseServer::MethodInterface::WorkBufferHandler::ResizeState ArticBaseServer::MethodInterface::WorkBufferHandler::ResizeLast(ArticBaseCommon::Buffer* buffer, size_t newSize) { if (buffer->bufferSize == newSize) return ResizeState::GOOD; if ((uintptr_t)workBuffer + offset != (uintptr_t)buffer + buffer->bufferSize + sizeof(ArticBaseCommon::Buffer)) return ResizeState::INPUT_ERROR; int sizeDiff = newSize - buffer->bufferSize; if (newSize == 0) { // Remove the buffer completely sizeDiff -= sizeof(ArticBaseCommon::Buffer); } if (offset + sizeDiff > workBufferSize) return ResizeState::OUT_OF_MEMORY; offset += sizeDiff; if (newSize != 0) buffer->bufferSize = newSize; return ResizeState::GOOD; }