Solving Protohackers challenges in Rust

Intro

Protohackers is a casual programming challenge in which you create servers for network protocols.

Each challenge requires you to host a publicly reachable server. Since my college network is IPv4 behind a NAT, I couldn't try this until I came home for Diwali, where I have IPv6.

Since most of my projects are small, I try to avoid dependencies. This makes it somewhat difficult when I try to contribute to larger projects - I need to learn about the crates they use. I thought Protohackers would be a good opportunity to learn some of the popular crates - this will be the theme throughout this post.

I'm skipping problem 4 because it's simple and can be solved in 30 LOC with just std.

Problem 0: Smoke Test

Link

The first problem asks you to implement the TCP Echo Service from RFC 862 - you simply have to send back whatever you receive. The actual task is simple:

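use std::io::{Read, Write};
use std::net::{TcpListener, TcpStream};

fn echo(mut stream: TcpStream) {
    // Read until the client closes its sending half of the connection.
    let mut buf = vec![];
    let _bytes_read = stream.read_to_end(&mut buf).unwrap();

    // write_all keeps going until the whole buffer is sent; write may send only part of it.
    stream.write_all(&buf).unwrap();
}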

However,

Make sure you don't mangle binary data, and that you can handle at least 5 simultaneous clients

The easiest way is to just spawn a thread for every connection:

fn main() {
    let listener = TcpListener::bind("[::]:8000").unwrap();
    for stream in listener.incoming() {
        std::thread::spawn(|| echo(stream.unwrap()));
    }
}

This passes the tests, but spawning a new thread for every connection doesn't scale well - we want a thread pool. The Book teaches you how to implement one, but our objective here is cratemaxxing - so let's look at rayon:

Rayon is a data-parallelism library for Rust. It is extremely lightweight and makes it easy to convert a sequential computation into a parallel one.

Rayon has a thread pool implementation, but its star feature is the high-level parallel constructs. Changing main to use a parallel iterator:

use rayon::iter::{ParallelBridge, ParallelIterator};

fn main() {
    let listener = TcpListener::bind("[::]:8000").unwrap();
    listener.incoming().par_bridge().for_each(|stream| {
        echo(stream.unwrap());
    });
}
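
As an aside, the echo function above holds the whole input in memory before replying. A streaming variant can be sketched with std::io::copy, which forwards bytes as they arrive. The name echo_stream is my own, and the function is generic over Read and Write so it can be tried without a socket; this is a sketch, not what I submitted.

```rust
use std::io::{self, Read, Write};

/// Streams every byte from `reader` to `writer` and returns how many were copied.
/// The bytes are never interpreted, so binary data passes through untouched.
fn echo_stream<R: Read, W: Write>(mut reader: R, mut writer: W) -> io::Result<u64> {
    let copied = io::copy(&mut reader, &mut writer)?;
    writer.flush()?;
    Ok(copied)
}
```

Since Read and Write are both implemented for &TcpStream, calling echo_stream(&stream, &stream) should work for a real connection.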

Problem 1: Prime Time

Link

The task here is to respond to requests asking whether a number is prime, communicating via newline-separated JSON objects.

A conforming request object has the required field method, which must always contain the string "isPrime", and the required field number, which must contain a number. Any JSON number is a valid number, including floating-point values.

For this we use serde and serde-jsonlines. Serde provides #[derive] macros that implement {de}serialization; they are enabled through serde's derive feature:

Cargo.toml
[dependencies]
serde = { version = "1", features = ["derive"] }

src/main.rs
use serde::{Deserialize, Serialize};

#[derive(Deserialize)]
struct Request {
    method: String,
    number: f64,
}

#[derive(Serialize)]
struct Response {
    method: String,
    prime: bool,
}

To use serde-jsonlines we import its BufReadExt and WriteExt extension traits. These traits extend the BufRead and Write traits to add the json_lines and write_json_lines methods respectively. The json_lines method returns an iterator over deserialized JSON values, and write_json_lines is self-explanatory. To use these traits we need to wrap the stream in BufReader and LineWriter respectively.

Ok, the LineWriter thing is not strictly true, but I couldn't get it to pass the borrow checker by writing directly to the owned stream, and buffering till newlines is nice anyway. The code for this:

use serde_jsonlines::{BufReadExt, WriteExt};
use std::io::{BufReader, LineWriter};

fn handle(stream: TcpStream) {
    let reader = BufReader::new(&stream);
    let mut writer = LineWriter::new(&stream);
    let lines = reader.json_lines::<Request>();
    for deserialized in lines {
        match deserialized {
            // Well-formed request: answer it and keep the connection open.
            Ok(request) if request.method.as_str() == "isPrime" => {
                writer
                    .write_json_lines([Response {
                        method: String::from("isPrime"),
                        prime: is_prime(request.number),
                    }])
                    .unwrap();
            }
            // Wrong method or invalid JSON: send one malformed response and disconnect.
            _ => {
                writer
                    .write_json_lines([Response {
                        method: String::from("malformed"),
                        prime: false,
                    }])
                    .unwrap();
                break;
            }
        };
    }
}

main is the same as before:

fn main() {
    let listener = TcpListener::bind("[::]:8000").unwrap();
    listener.incoming().par_bridge().for_each(|stream| {
        handle(stream.unwrap());
    });
}

This part is not interesting, but for the sake of completeness - I checked primality by trial division with odd numbers up to the square root:

fn is_prime(number: f64) -> bool {
    // NaN, negatives, non-integers, 0 and 1 are never prime.
    if number.is_nan()
        || number.is_sign_negative()
        || number.floor() != number
        || number == 0.0
        || number == 1.0
    {
        return false;
    }

    // Only divisors up to the square root need to be tested.
    let max = number.sqrt().floor();
    assert_eq!(max, (max as u128) as f64);
    let (number, max) = (number as u128, max as u128);
    if number != 2 && number % 2 == 0 {
        return false;
    }
    for test in (3..=max).step_by(2) {
        if number % test == 0 {
            return false;
        }
    }
    true
}
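
For comparison, here is a minimal sketch of the same trial-division idea on plain integers, without the f64 conversions. The name is_prime_u64 is made up for this example and the function is not part of my solution.

```rust
/// Trial division on unsigned integers: handle 2 separately, then try odd
/// divisors up to the square root of `n`.
fn is_prime_u64(n: u64) -> bool {
    if n < 2 {
        return false;
    }
    if n % 2 == 0 {
        return n == 2;
    }
    let mut d: u64 = 3;
    // `d <= n / d` is the same test as `d * d <= n`, but it cannot overflow.
    while d <= n / d {
        if n % d == 0 {
            return false;
        }
        d += 2;
    }
    true
}
```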

Problem 2: Means to an end

Link

The task here is to maintain a database for each client, who will send requests to insert and query it. This one uses a binary format, so let's look at binrw:

binrw helps you write maintainable & easy-to-read declarative binary data readers and writers using ✨macro magic✨.

The format is simple: each message is 9 bytes - check the first byte for the message type, then parse the next 8 bytes as two signed 32-bit big-endian integers. This can be concisely expressed by binrw:

use binrw::BinRead;

#[derive(BinRead)]
#[br(big)]
enum Message {
    #[br(magic(b'I'))]
    Insert { timestamp: i32, price: i32 },
    #[br(magic(b'Q'))]
    Query { mintime: i32, maxtime: i32 },
}

let message = Message::read(&mut stream);

However, we can't read directly from a TcpStream:

error[E0277]: the trait bound `TcpStream: Seek` is not satisfied
  --> src/main.rs:16:33
   |
16 |     let message = Message::read(&mut stream);
   |                   ------------- ^^^^^^^^^^^ the trait `Seek` is not implemented for `TcpStream`
   |                   |
   |                   required by a bound introduced by this call
   |
   = help: the following other types implement trait `Seek`:
             Box<S>
             binrw::io::BufReader<T>
             NoSeek<T>
             File
             std::io::BufReader<R>
             BufWriter<W>
             TakeSeek<T>
             Arc<File>
           and 4 others
note: required by a bound in `binrw::BinRead::read`
  --> /home/chinmay/.local/share/cargo/registry/src/index.crates.io-6f17d22bba15001f/binrw-0.13.1/src/binread/mod.rs:57:23
   |
57 |     fn read<R: Read + Seek>(reader: &mut R) -> BinResult<Self>
   |                       ^^^^ required by this bound in `BinRead::read`

TcpStream doesn't implement Seek - which makes sense, you can't go back in a TCP stream. Fortunately, binrw provides NoSeek:

A wrapper that provides a limited implementation of Seek for unseekable Read and Write streams. This is useful when reading or writing from unseekable streams where binrw does not actually need to seek to successfully parse or write the data.

Let's try this then:

let mut reader = NoSeek::new(&stream);
let mut writer = NoSeek::new(&stream);
let message = Message::read(&mut reader).unwrap();

This panics:

╺━━━━━━━━━━━━━━━━━━━━┅ Backtrace ┅━━━━━━━━━━━━━━━━━━━━╸

0: Error: seek on unseekable file
rewinding after a failure
1: bad magic at 0x24: 81

╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸

If we remove the unwrap() and add a dbg!(&message), we see that it successfully parses Insert messages before failing. This is because when it encounters a Query message, it fails on the first variant and tries to rewind the stream to try the second one - which fails because NoSeek only provides a dummy Seek implementation. One of the developers explained the solution:

binrw::io::BufReader<binrw::io::NoSeek<TcpStream>>. NoSeek enables the Seek trait on BufReader; the Seek trait on BufReader tries to do relative seeks within its own buffer, so long as the seeked-to position is within the bounds of the buffer, it will not call NoSeek::seek, and so it will not fail rewinding

This is the handle() function after applying the fix:

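use binrw::io::{BufReader, NoSeek};
use binrw::BinWriterExt;
use std::collections::BTreeMap;
use std::net::TcpStream;

fn handle(stream: TcpStream) {
    let mut db = BTreeMap::new();
    // BufReader can rewind inside its own buffer, so NoSeek::seek is never reached.
    let mut reader = BufReader::new(NoSeek::new(&stream));
    let mut writer = NoSeek::new(&stream);
    while let Ok(message) = Message::read(&mut reader) {
        match message {
            Message::Insert { timestamp, price } => {
                db.insert(timestamp, price);
            }
            Message::Query { mintime, maxtime } => {
                // An inverted range has no samples (and BTreeMap::range would panic on it).
                if mintime > maxtime {
                    writer.write_be(&0i32).unwrap();
                    continue;
                }
                let mut sum = 0.0;
                let mut n = 0;
                for (_, &price) in db.range(mintime..=maxtime) {
                    sum += price as f64;
                    n += 1;
                }
                // n.max(1) avoids dividing by zero when no samples match.
                let average = sum / n.max(1) as f64;
                let average = average as i32;
                writer.write_be(&average).unwrap();
            }
        }
    }
}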
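
To make the wire format and the query logic concrete, here is a std-only sketch of the same decoding and averaging. Frame, decode and mean_price are names I made up for this example; it sums in i64 instead of f64, and it is not the binrw-based code above.

```rust
use std::collections::BTreeMap;

#[derive(Debug, PartialEq)]
enum Frame {
    Insert { timestamp: i32, price: i32 },
    Query { mintime: i32, maxtime: i32 },
}

/// Decodes one 9-byte frame: a type byte followed by two big-endian i32 values.
fn decode(frame: &[u8; 9]) -> Option<Frame> {
    let a = i32::from_be_bytes([frame[1], frame[2], frame[3], frame[4]]);
    let b = i32::from_be_bytes([frame[5], frame[6], frame[7], frame[8]]);
    match frame[0] {
        b'I' => Some(Frame::Insert { timestamp: a, price: b }),
        b'Q' => Some(Frame::Query { mintime: a, maxtime: b }),
        _ => None,
    }
}

/// Mean price over an inclusive time range; 0 if the range is empty or inverted.
fn mean_price(db: &BTreeMap<i32, i32>, mintime: i32, maxtime: i32) -> i32 {
    if mintime > maxtime {
        return 0;
    }
    let (sum, count) = db
        .range(mintime..=maxtime)
        .fold((0i64, 0i64), |(sum, count), (_, &price)| {
            (sum + i64::from(price), count + 1)
        });
    if count == 0 {
        0
    } else {
        // Integer division truncates toward zero, like the `as i32` cast above.
        (sum / count) as i32
    }
}
```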

Problem 3: Budget Chat

Link

Tokio's chat server example is extremely similar and well commented, so I'll skip the explanation. We just need to make some changes.

Change some messages

To comply with the spec, the "user has joined" and "user has left" messages need to start with an asterisk.

List all existing users to a new joinee

For this we need to keep track of usernames in the common state:

struct Shared {
    peers: HashMap<SocketAddr, (Tx, String)>,
}

Then, before registering the new user, we send them the list of existing users:

let mut message = String::from("* The room contains:");
for (_, existing_name) in state.lock().await.peers.values() {
    message += &format!(" {existing_name},");
}
lines.send(message).await?;
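
The loop above leaves a trailing comma after the last name. That seems to be accepted, but if you'd rather avoid it, the message can be built with join. This is a small sketch; room_message is a hypothetical helper, not something from the Tokio example.

```rust
/// Builds the notice sent to a newly joined user, separating the names with
/// ", " and leaving no trailing comma.
fn room_message(names: &[&str]) -> String {
    format!("* The room contains: {}", names.join(", "))
}
```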

Disallow illegal usernames

The first message from a client sets the user's name, which must contain at least 1 character, and must consist entirely of alphanumeric characters (uppercase, lowercase, and digits).
Implementations may limit the maximum length of a name, but must allow at least 16 characters. Implementations may choose to either allow or reject duplicate names.
If the user requests an illegal name, the server may send an informative error message to the client, and the server must disconnect the client, without sending anything about the illegal user to any other clients.

The check is straightforward:

fn is_legal(name: &str) -> bool {
    // A name needs at least 1 character; 20 is our own upper limit.
    if name.is_empty() || name.len() > 20 {
        return false;
    }
    for char in name.chars() {
        if !char.is_ascii_alphanumeric() {
            return false;
        }
    }
    true
}

Then, before listing the existing users to the new user, we check if the username is legal:

if !is_legal(&username) {
    eprintln!("Illegal username {username} from {addr}. Client disconnected.");
    return Ok(());
}
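
To see how the rules in the quote play out, here is a compact sketch of such a check with a few sample names in mind. is_legal_name and MAX_NAME_LEN are names made up for this example, and the limit of 20 is just the one used in this post, not something the spec requires.

```rust
/// Upper bound on name length; the spec only requires allowing at least 16.
const MAX_NAME_LEN: usize = 20;

/// A name is legal if it has 1 to MAX_NAME_LEN characters, all ASCII letters or digits.
fn is_legal_name(name: &str) -> bool {
    !name.is_empty()
        && name.len() <= MAX_NAME_LEN
        && name.bytes().all(|b| b.is_ascii_alphanumeric())
}
```

With this check, "Bob42" is accepted, while an empty name, "bob smith" (contains a space) and a 21-character name are all rejected.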

Problem 5: Mob in the Middle

Link

The task here is to intercept messages between clients and a server, and replace "Boguscoin" addresses with a mob boss's address.

Replacing the address

Your server will rewrite Boguscoin addresses, in both directions, so that they are always changed to Tony's address instead. A substring is considered to be a Boguscoin address if it satisfies all of:

The last two points are about the boundaries of the pattern. The regex crate has word boundary assertions, but those would also match next to things like hyphens, which we don't want. What we need is lookaround, which the regex crate doesn't support. The fancy-regex crate does, so we can write our pattern (a 7 followed by 25 to 34 alphanumeric characters, with whitespace or the start/end of the message on either side) like this:

fancy_regex::Regex::new(r#"(?<=^|\s)7[[:alnum:]]{25,34}(?=$|\s)"#);

Regex compilation is expensive, so we need to make sure to compile it only once. Earlier, we'd have to use a crate like lazy_static, but OnceLock has been part of std since Rust 1.70.0:

use fancy_regex::Regex;
use std::borrow::Cow;
use std::sync::OnceLock;

fn replace_addr(message: &str) -> Cow<'_, str> {
    static REGEX: OnceLock<Regex> = OnceLock::new();
    REGEX
        // get_or_init works on stable Rust; get_or_try_init needs #![feature(once_cell_try)].
        .get_or_init(|| Regex::new(r#"(?<=^|\s)7[[:alnum:]]{25,34}(?=$|\s)"#).unwrap())
        .replace_all(message, "7YWHMfk9JZe0LM0g1ZauHuiSxhI")
}
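
To make the boundary rules concrete, here is a regex-free sketch that splits a line into tokens and swaps out the ones that look like addresses. Note the assumptions: it treats only the space character (plus one trailing newline) as a boundary, which is narrower than the \s in the pattern above, and is_address and rewrite_line are names I made up. It is not the code I submitted.

```rust
const TONY: &str = "7YWHMfk9JZe0LM0g1ZauHuiSxhI";

/// True if `token` looks like an address: starts with '7' and consists of
/// 26 to 35 ASCII alphanumeric characters in total.
fn is_address(token: &str) -> bool {
    (26..=35).contains(&token.len())
        && token.starts_with('7')
        && token.bytes().all(|b| b.is_ascii_alphanumeric())
}

/// Rewrites every space-delimited address in one line, keeping a trailing newline.
fn rewrite_line(line: &str) -> String {
    let (body, newline) = match line.strip_suffix('\n') {
        Some(body) => (body, "\n"),
        None => (line, ""),
    };
    let rewritten: Vec<&str> = body
        .split(' ')
        .map(|token| if is_address(token) { TONY } else { token })
        .collect();
    format!("{}{}", rewritten.join(" "), newline)
}
```

Because a token must consist entirely of alphanumeric characters, something like two addresses joined by a hyphen is left alone, which is the behaviour the lookaround assertions are there to guarantee.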

Event loop

Since we need to watch two connections for each client, we'll have to use an event loop. Let's look at polling.

First, we need to set the streams to nonblocking mode:

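use polling::{Event, Events, PollMode, Poller};
use std::io::{self, Read, Write};
use std::net::TcpStream;

fn handle(mut client: TcpStream) -> io::Result<()> {
    let mut server = TcpStream::connect(("chat.protohackers.com", 16963))?;
    client.set_nonblocking(true)?;
    server.set_nonblocking(true)?;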

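This will result in read, write, recv and send operations becoming nonblocking, i.e., immediately returning from their calls. If an operation cannot be completed right away, it returns an error of kind WouldBlock instead of waiting.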
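
As a standalone illustration of what nonblocking mode means for read, here is a sketch that maps the "nothing to read yet" case to None. try_read and loopback_pair are hypothetical helpers written for this example; they are not part of handle.

```rust
use std::io::{self, Read};
use std::net::{TcpListener, TcpStream};

/// Reads once from a nonblocking stream; `Ok(None)` means no data is ready yet.
fn try_read(stream: &mut TcpStream, buf: &mut [u8]) -> io::Result<Option<usize>> {
    match stream.read(buf) {
        Ok(n) => Ok(Some(n)),
        Err(e) if e.kind() == io::ErrorKind::WouldBlock => Ok(None),
        Err(e) => Err(e),
    }
}

/// Opens a loopback connection and returns (connecting side, accepted side).
fn loopback_pair() -> io::Result<(TcpStream, TcpStream)> {
    let listener = TcpListener::bind("127.0.0.1:0")?;
    let client = TcpStream::connect(listener.local_addr()?)?;
    let (accepted, _) = listener.accept()?;
    Ok((client, accepted))
}
```

Calling try_read on a nonblocking stream whose peer hasn't sent anything returns Ok(None) immediately, where a blocking stream would wait.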

Then, we create a Poller which waits for I/O events, use arbitrary keys to identify the client and server sockets, and add the sockets with their keys to the Poller. We use add_with_mode instead of add because the default mode is one-shot, and we want to continuously receive events on the sockets. We also need to be careful to delete the sockets from the Poller before they are dropped as per add_with_mode's safety notice.

    let poller = Poller::new()?;
    let (client_key, server_key) = (42, 43);
    unsafe {
        // Safety: we delete `client` and `server` from `poller` before they are dropped
        // https://docs.rs/polling/latest/polling/struct.Poller.html#safety-1
        poller.add_with_mode(
            &client,
            Event::readable(client_key).with_interrupt(),
            PollMode::Edge,
        )?;
        poller.add_with_mode(
            &server,
            Event::readable(server_key).with_interrupt(),
            PollMode::Edge,
        )?;
    }
    /*
    event loop here
    */
    poller.delete(&client)?;
    poller.delete(&server)?;

Then, we write the event loop - on every iteration of the loop, we clear the event queue, wait for events, then process all the events received. Since the events are identified by their key, we can process them accordingly.

We break out of the loop whenever we read 0 bytes, because this indicates a closed stream. On breaking out of the loop, both the client and server sockets are deleted from the Poller (see above) and both the streams are closed when they are dropped.

    let mut events = Events::new();

    let (mut server_buf, mut client_buf) = ([0u8; 2000], [0u8; 2000]);
    let (mut server_read, mut client_read) = (0, 0);
    'outer: loop {
        events.clear();
        poller.wait(&mut events, None)?;

        for event in events.iter() {
            if event.key == server_key {
                let bytes_read = server.read(&mut server_buf[server_read..])?;
                if bytes_read == 0 {
                    // Zero bytes read: treat the stream as closed.
                    break 'outer;
                } else {
                    server_read += bytes_read;
                }
                // Forward only once the buffered data ends with a newline.
                if server_buf[server_read - 1] == b'\n' {
                    let message = std::str::from_utf8(&server_buf[..server_read]).unwrap();
                    // write_all, unlike write, does not silently stop after a partial write.
                    client.write_all(replace_addr(message).as_bytes())?;
                    client.flush()?;
                    server_read = 0;
                }
            } else if event.key == client_key {
                let bytes_read = client.read(&mut client_buf[client_read..])?;
                if bytes_read == 0 {
                    break 'outer;
                } else {
                    client_read += bytes_read;
                }
                if client_buf[client_read - 1] == b'\n' {
                    let message = std::str::from_utf8(&client_buf[..client_read]).unwrap();
                    server.write_all(replace_addr(message).as_bytes())?;
                    server.flush()?;
                    client_read = 0;
                }
            }
            if event.is_interrupt() {
                break 'outer;
            }
        }
    }
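
The loop above assumes that a read ends exactly at a newline. A more defensive way to frame lines, sketched here with std only, is to accumulate bytes and hand back every complete line, keeping any partial line for the next read. LineBuffer is a hypothetical type for illustration and is not used in my solution.

```rust
/// Accumulates raw bytes and hands back complete, newline-terminated lines.
#[derive(Default)]
struct LineBuffer {
    pending: Vec<u8>,
}

impl LineBuffer {
    /// Appends `chunk` and returns every complete line that is now available.
    /// Bytes after the last newline stay buffered until a later call.
    fn push(&mut self, chunk: &[u8]) -> Vec<Vec<u8>> {
        self.pending.extend_from_slice(chunk);
        let mut lines = Vec::new();
        while let Some(pos) = self.pending.iter().position(|&b| b == b'\n') {
            // Keep everything after the newline; hand back the line itself.
            let rest = self.pending.split_off(pos + 1);
            lines.push(std::mem::replace(&mut self.pending, rest));
        }
        lines
    }
}
```

Each returned line still includes its newline, so it could be passed through replace_addr and forwarded unchanged in length handling.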