Add support for PROXY protocol.

This commit is contained in:
bʰedoh₂ swé 2025-02-12 12:12:22 +05:00
parent c0db9ff1a1
commit 3b68d8a097
3 changed files with 83 additions and 8 deletions

View File

@ -17,7 +17,7 @@ public class Main {
case "help" -> {
System.out.println("crab help - print this message.");
System.out.println("crab client <ip> <port> [nick] - connect to a server.");
System.out.println("crab server <port> - start a server.");
System.out.println("crab server <port> [PROXY protocol off/on] - start a server.");
}
case "client" -> {
CrabClient client;
@ -31,12 +31,20 @@ public class Main {
}
case "server" -> {
CrabServer server;
if (args.length > 1) {
boolean isProxied = false;
if (args.length > 2)
isProxied = args[2].equals("on") ? true : false;
try {
server = new CrabServer(Integer.parseInt(args[1]));
server = new CrabServer(Integer.parseInt(args[1]), isProxied);
} catch (NumberFormatException e) {
System.err.println("Port is not a number.");
return;
}
} else {
System.err.println("Now enough arguments.");
return;
}
server.run();
}
default -> {

View File

@ -3,6 +3,7 @@ import net.pixtaded.crab.common.Crab;
import net.pixtaded.crab.common.Logs;
import java.io.IOException;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.Scanner;
@ -11,6 +12,7 @@ public class CrabServer implements Crab {
private ServerSocket serverSocket;
private boolean isStopped = false;
private boolean isProxied = false;
private int port;
private final Database db;
public Logs cache = new Logs(0, "");
@ -19,9 +21,10 @@ public class CrabServer implements Crab {
this.db = new Database("data.db");
}
public CrabServer(int port) {
public CrabServer(int port, boolean isProxied) {
this.db = new Database("data.db");
this.port = port;
this.isProxied = isProxied;
}
@Override
@ -53,11 +56,28 @@ public class CrabServer implements Crab {
System.out.println("Enter a correct port number: ");
}
}
System.out.print("Enable PROXY protocol? (on/off): ");
while (true) {
String s = scanner.nextLine();
if (s.equals("on")) {
this.isProxied = true;
break;
}
if (s.equals("off")) {
this.isProxied = false;
break;
}
System.out.println("Enter either \"on\" or \"off\".");
}
}
private void listen() throws IOException {
Scanner scanner = new Scanner(System.in);
if (this.isProxied) {
serverSocket = new ServerSocket(port, 0, InetAddress.getLoopbackAddress());
} else {
serverSocket = new ServerSocket(port);
}
System.out.printf("Server successfully started! Listening on port %s.\nTo stop the server, type 'q'.\n", port);
ServerCLI cli = new ServerCLI(scanner, this);
new Thread(cli).start();
@ -82,4 +102,8 @@ public class CrabServer implements Crab {
public Database getDb() {
return db;
}
public boolean isProxied() {
return this.isProxied;
}
}

View File

@ -7,6 +7,7 @@ import net.pixtaded.crab.common.Util;
import java.io.*;
import java.net.Socket;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Date;
import static net.pixtaded.crab.common.PID.*;
@ -37,11 +38,40 @@ public class ServerThread implements Runnable {
socket.close();
return;
}
String address = socket.getInetAddress().getHostAddress();
if (PID[0] == 'P') {
if (!this.server.isProxied()) {
System.err.println(address + " tried to use PROXY despite it being off.");
socket.close();
return;
}
if (Arrays.equals(readUntilChar(' '),"ROXY".getBytes())) {
readUntilChar(' '); // proto
byte source[] = readUntilChar(' ');
address = new String(source);
readUntilChar(' '); // destination IP
readUntilChar(' '); // source port
readUntilChar('\r'); // destination port
if (input.read() != '\n') {
System.err.println("Invalid PROXY packet.");
socket.close();
return;
}
} else {
System.err.println("Invalid PROXY packet header.");
socket.close();
return;
}
PID = readPID();
if (PID.length == 0) {
socket.close();
return;
}
}
switch (PID[0]) {
case MESSAGE -> {
String msg = new String(input.readNBytes(4096), StandardCharsets.UTF_8).trim();
Date date = new Date();
String address = socket.getInetAddress().getHostAddress();
String s = Sanitizer.sanitizeString(msg, true);
String newContent = server.cache.content() + Sanitizer.formatMessage(date.getTime(), address, s);
@ -68,6 +98,19 @@ public class ServerThread implements Runnable {
return input.readNBytes(1);
}
private byte[] readUntilChar(char c) throws IOException {
byte b[] = new byte[256];
int i;
for (i = 0;; i++) {
b[i] = (byte)input.read();
if (b[i] == c)
break;
}
byte r[] = new byte[i];
System.arraycopy(b, 0, r, 0, i);
return r;
}
private void sendLogs(byte PID) throws IOException {
if (PID == LOGS) {
respond(server.cache.content());