From b04153bd8d96cfec7e615acc9ea28bb5f8580c12 Mon Sep 17 00:00:00 2001 From: clfreville2 Date: Wed, 17 Jan 2024 09:10:18 +0100 Subject: [PATCH] Define the protocol --- README.md | 47 ++++++++++++++++++++++++ src/main.cpp | 95 ++++++++++++++++++++++++++++++++++--------------- src/network.cpp | 16 +++++++++ src/network.hpp | 9 +++++ src/runner.cpp | 13 ++++++- src/runner.hpp | 1 + 6 files changed, 151 insertions(+), 30 deletions(-) diff --git a/README.md b/README.md index 0a7e43a..c77c7d0 100644 --- a/README.md +++ b/README.md @@ -4,3 +4,50 @@ planificador A sandbox execution environment for untrusted code. It acts as a front-end in front of Docker+Bubblewrap. Tasks are submitted using a ZeroMQ message queue, allowing quick scaling of the system. + +Protocol +-------- + +*planificador* receives messages from a ZeroMQ queue in binary format in big-endian. + +Executor bound +-------------- + +The first byte of the message is the message type. The following types are supported: + +* `0x00`: `SUBMIT` - Submit a new task to the system. +* `0x02`: `CANCEL` - Cancel a task. + +The following bytes are the payload of the message. The format of the payload depends on the message type. + +### SUBMIT + +- 32 bytes: Task ID +- 4 bytes: Image field length +- 4 bytes: Code length +- Image field length bytes: Image field +- Code length bytes: Code + +### CANCEL + +- 32 bytes: Task ID + +Client bound +------------ + +The first byte of the message is the message type. The following types are supported: + +* `0x01`: `APPEND_OUT` - Append text to the task's stdout. +* `0x02`: `APPEND_ERR` - Append text to the task's stderr. +* `0x03`: `EXITED` - The task has exited. + +## APPEND_OUT / APPEND_ERR + +- 32 bytes: Task ID +- 4 bytes: Text length +- Text length bytes: Text + +## EXITED + +- 32 bytes: Task ID +- 4 bytes: Exit code diff --git a/src/main.cpp b/src/main.cpp index 355b23c..04bc138 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -5,15 +5,21 @@ #include #include #include +#include #include #include -#include #include "config.hpp" #include "zmq_addon.hpp" -static constexpr uint32_t JOB_ID_LEN = 32; -static constexpr uint32_t MIN_MESSAGE_LEN = JOB_ID_LEN + sizeof(uint32_t) * 2; +static constexpr uint32_t MIN_SUBMIT_MESSAGE_LEN = sk::JOB_ID_LEN + sizeof(uint32_t) * 2; +static constexpr uint32_t MIN_CANCEL_MESSAGE_LEN = sk::JOB_ID_LEN + sizeof(uint32_t); + +static constexpr int SUBMIT_EXECUTOR_BOUND = 0; +static constexpr int CANCEL_EXECUTOR_BOUND = 1; +static constexpr int STDOUT_CLIENT_BOUND = 1; +static constexpr int STDERR_CLIENT_BOUND = 2; +static constexpr int EXIT_CLIENT_BOUND = 3; sk::runner_backend detect_backend() { const char *const argv[] = {"docker", "stats", "--no-stream", nullptr}; @@ -90,37 +96,68 @@ int main(int argc, char **argv) { zmq::socket_t sender(context, zmq::socket_type::push); sender.connect(config.queue.push_addr); + auto send = [&sender](int type, const std::string &jobId, const std::string &text) { +#ifndef NDEBUG + std::cout << "Result: `" << text << "`\n"; +#endif + auto [reply, reply_bytes] = sk::prepare_headers(sizeof(uint32_t) + text.size(), type, jobId); + sk::write_string(reply_bytes, text); + sender.send(reply, zmq::send_flags::none); + }; + while (true) { zmq::message_t request; - receiver.recv(request); - if (request.size() < MIN_MESSAGE_LEN) { - std::cerr << "Invalid request" << std::endl; - continue; - } + zmq::recv_result_t _ = receiver.recv(request); + const auto *message = static_cast(request.data()) + 1; + auto *message_bytes = static_cast(request.data()) + 1; + int message_type = static_cast(*static_cast(request.data())); + switch (message_type) { + case SUBMIT_EXECUTOR_BOUND: { + if (request.size() < MIN_SUBMIT_MESSAGE_LEN) { + std::cerr << "Invalid request\n"; + continue; + } + std::string jobId(message, sk::JOB_ID_LEN); + uint32_t imageLen = sk::read_uint32(message_bytes + sk::JOB_ID_LEN); + uint32_t codeLen = sk::read_uint32(message_bytes + sk::JOB_ID_LEN + sizeof(uint32_t)); - std::string jobId(static_cast(request.data()), JOB_ID_LEN); - uint32_t imageLen = sk::read_uint32(static_cast(request.data()) + JOB_ID_LEN); - uint32_t codeLen = sk::read_uint32(static_cast(request.data()) + JOB_ID_LEN + sizeof(uint32_t)); + if (request.size() < MIN_SUBMIT_MESSAGE_LEN + imageLen + codeLen) { + std::cerr << "Request is too short\n"; + continue; + } + std::string imageString(message + MIN_SUBMIT_MESSAGE_LEN, imageLen); + std::string requestString(message + MIN_SUBMIT_MESSAGE_LEN + imageLen, codeLen); - if (request.size() < MIN_MESSAGE_LEN + imageLen + codeLen) { - std::cerr << "Invalid request" << std::endl; - continue; - } - std::string imageString(static_cast(request.data()) + MIN_MESSAGE_LEN, imageLen); - std::string requestString(static_cast(request.data()) + MIN_MESSAGE_LEN + imageLen, codeLen); - - std::cout << "Executing " << codeLen << " bytes code.\n"; - sk::program program{jobId, requestString, imageString}; - sk::run_result result = runner.run_blocking(program); +#ifndef NDEBUG + std::cout << "Executing " << codeLen << " bytes code.\n"; +#endif + sk::program program{std::move(jobId), std::move(requestString), std::move(imageString)}; + sk::run_result result = runner.run_blocking(program); - std::cout << "Result: " << result.out << std::endl; - - // Send the job id, the exit code and result.out to sink - zmq::message_t reply(JOB_ID_LEN + sizeof(uint32_t) + result.out.size()); - memcpy(reply.data(), jobId.data(), JOB_ID_LEN); - sk::write_uint32(static_cast(reply.data()) + JOB_ID_LEN, result.exit_code); - memcpy(static_cast(reply.data()) + JOB_ID_LEN + sizeof(uint32_t), result.out.data(), result.out.size()); - sender.send(reply, zmq::send_flags::none); + if (!result.out.empty()) { + send(STDOUT_CLIENT_BOUND, program.name, result.out); + } + if (!result.err.empty()) { + send(STDERR_CLIENT_BOUND, program.name, result.err); + } + auto [reply, reply_bytes] = sk::prepare_headers(sizeof(uint32_t), EXIT_CLIENT_BOUND, program.name); + sk::write_uint32(reply_bytes, result.exit_code); + sender.send(reply, zmq::send_flags::none); + break; + } + case CANCEL_EXECUTOR_BOUND: { + if (request.size() < MIN_CANCEL_MESSAGE_LEN) { + std::cerr << "Invalid request\n"; + continue; + } + std::string jobId(message, sk::JOB_ID_LEN); + runner.kill_active(jobId); + break; + } + default: + std::cerr << "Invalid " << std::hex << message_type << " message type\n"; + break; + } } return 0; } diff --git a/src/network.cpp b/src/network.cpp index 6bc11f2..f51fc8b 100644 --- a/src/network.cpp +++ b/src/network.cpp @@ -1,5 +1,7 @@ #include "network.hpp" +#include + namespace sk { uint32_t read_uint32(const std::byte *buffer) { return static_cast(buffer[3]) | static_cast(buffer[2]) << 8 | static_cast(buffer[1]) << 16 | static_cast(buffer[0]) << 24; } void write_uint32(std::byte *buffer, uint32_t value) { @@ -8,4 +10,18 @@ void write_uint32(std::byte *buffer, uint32_t value) { buffer[2] = static_cast(value >> 8); buffer[3] = static_cast(value); } + +void write_string(std::byte *buffer, std::string_view text) { + auto size = static_cast(text.size()); + write_uint32(buffer, size); + memcpy(buffer, text.data(), size); +} + +std::tuple prepare_headers(size_t data_len, int type, std::string_view jobId) { + zmq::message_t reply(1 + JOB_ID_LEN + data_len); + auto *reply_bytes = static_cast(reply.data()); + *reply_bytes = static_cast(type); + memcpy(reply_bytes + 1, jobId.data(), JOB_ID_LEN); + return {std::move(reply), reply_bytes + 1 + JOB_ID_LEN}; +} } diff --git a/src/network.hpp b/src/network.hpp index 2fa45e5..8b75dc3 100644 --- a/src/network.hpp +++ b/src/network.hpp @@ -2,8 +2,17 @@ #include #include +#include +#include +#include namespace sk { + +static constexpr uint32_t JOB_ID_LEN = 32; + uint32_t read_uint32(const std::byte *buffer); void write_uint32(std::byte *buffer, uint32_t value); +void write_string(std::byte *buffer, std::string_view text); + +std::tuple prepare_headers(size_t data_len, int type, std::string_view jobId); } diff --git a/src/runner.cpp b/src/runner.cpp index f3972da..b855335 100644 --- a/src/runner.cpp +++ b/src/runner.cpp @@ -7,9 +7,9 @@ #include #include #include +#include #include #include -#include // Define a helper to throw a system error if a syscall fails static auto ensure = [](int res) -> void { @@ -165,6 +165,17 @@ run_result runner::run_blocking(const program &program) { return run_result{out, err, killed ? 124 : exit_code}; } +bool runner::kill_active(const std::string &jobId) { + std::lock_guard guard(active_jobs_mutex); + auto it = std::find_if(active_jobs.begin(), active_jobs.end(), [&jobId](const active_job &job) { return job.job_id == jobId; }); + if (it != active_jobs.end()) { + exit(*it); + active_jobs.erase(it); + return true; + } + return false; +} + void runner::exit_active_jobs() { std::lock_guard guard(active_jobs_mutex); for (const auto &job : active_jobs) { diff --git a/src/runner.hpp b/src/runner.hpp index 9d540ee..9fe7505 100644 --- a/src/runner.hpp +++ b/src/runner.hpp @@ -29,6 +29,7 @@ class runner { public: runner(runner_backend backend, const runner_config &config); run_result run_blocking(const program &program); + bool kill_active(const std::string &jobId); void exit_active_jobs(); private: