/*
 * $Id: addrcount.c,v 1.2 2001/01/16 03:38:46 wessels Exp wessels $
 * 
 * Counts packets and octets (bytes) in and out for each IP address in a
 * contiguous range of addresses on a subnet.
 * 
 * Uses libpcap
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <syslog.h>
#include <strings.h>
#include <pcap.h>
#include <assert.h>
#include <errno.h>
#include <fcntl.h>
#include <sys/signal.h>
#include <sys/wait.h>

#include <sys/socket.h>

#include <net/if.h>
#include <netinet/in.h>
#include <netinet/if_ether.h>
#include <netinet/in_systm.h>
#include <netinet/ip.h>
#include <netinet/ip_var.h>
#include <netinet/udp.h>
#include <netinet/udp_var.h>
#include <netinet/tcp.h>
#include <netinet/tcpip.h>

#include <arpa/inet.h>


/* Capture length given to pcap_open_live(); only the Ethernet and IPv4
 * headers are ever inspected. */
#define SNAPLEN 68
/* TCP port (bound to 127.0.0.1) on which the counters are served. */
#define ADDRCOUNT_PORT 2346
/* Capacity of Addrstats[], i.e. the largest usable naddrs. */
#define MAXNADDRS 256

/*
 * Traffic counters for one address of the monitored range.  The fields
 * are plain unsigned ints and simply wrap at UINT_MAX.
 * NOTE(review): octet counters can wrap on a busy link for a long-running
 * daemon; consider wider types if that matters.
 */
struct addrstat {
    unsigned int pktin;		/* packets with this address as destination */
    unsigned int pktout;	/* packets with this address as source */
    unsigned int octin;		/* octets of those inbound packets */
    unsigned int octout;	/* octets of those outbound packets */
};

/*
 * First monitored address, stored in HOST byte order (despite the
 * struct in_addr type) so that an address's slot is simply
 * ntohl(addr) - ip_min.s_addr.
 */
static struct in_addr ip_min;
/* Listening socket for stats clients, or -1 when none is open. */
static int Listen = -1;
/* Per-address counters; slot i belongs to address ip_min + i. */
static struct addrstat Addrstats[MAXNADDRS];
/* Number of slots in use, set from the command line. */
static int naddrs = 0;

void
packet_handle(unsigned char *user, const struct pcap_pkthdr * h, const unsigned char *p)
{
    unsigned int length = h->len;
    struct ip *ip;
    const struct ether_header *ep = (struct ether_header *) p;
    unsigned short ethertype = ntohs(ep->ether_type);
    unsigned int dst_index;
    unsigned int src_index;
    int ok = 0;

    if (ETHERTYPE_IP != ethertype)
	return;
    p += sizeof(*ep);
    length -= sizeof(*ep);

    /* L3 */
    ip = (struct ip *) p;
    if (length < sizeof(*ip)) {
	syslog(LOG_WARNING, "truncated IP packet (%d)", length);
	return;
    }
    dst_index = ntohl(ip->ip_dst.s_addr) - ip_min.s_addr;
    src_index = ntohl(ip->ip_src.s_addr) - ip_min.s_addr;
    if (dst_index < naddrs) {
	Addrstats[dst_index].pktin++;
	Addrstats[dst_index].octin += length;
	ok++;
    }
    if (src_index < naddrs) {
	Addrstats[src_index].pktout++;
	Addrstats[src_index].octout += length;
	ok++;
    }
    if (!ok) {
	syslog(LOG_WARNING, "IP packet not to/from here (src=%08lx,dst=%08lx)",
	    ntohl(ip->ip_src.s_addr),
	    ntohl(ip->ip_dst.s_addr));
	return;
    }
}

static void
check_connect(void)
{
    unsigned int i;
    char buf[128];
    int s;
    int flags;
    int dummy;
    s = accept(Listen, NULL, NULL);
    if (s < 0)
	return;
    if (fork()) {
	close(s);
	return;
    }
    if ((flags = fcntl(s, F_GETFL, dummy)) < 0) {
	syslog(LOG_ERR, "fcntl F_GETFL: %s\n", strerror(errno));
	close(s);
	_exit(1);
    }
    if (fcntl(s, F_SETFL, flags & (~O_NONBLOCK)) < 0) {
	syslog(LOG_ERR, "fcntl F_SETFL: %s\n", strerror(errno));
	close(s);
	_exit(1);
    }
    for (i = 0; i < naddrs; i++) {
	struct in_addr a;
	if (Addrstats[i].pktin)
	    (void) 0;
	else if (Addrstats[i].pktout)
	    (void) 0;
	else if (Addrstats[i].octin)
	    (void) 0;
	else if (Addrstats[i].octout)
	    (void) 0;
	else
	    continue;
	a.s_addr = htonl(ip_min.s_addr + i);
	snprintf(buf, 128, "%s %u %u %u %u\n",
	    inet_ntoa(a),
	    Addrstats[i].pktin,
	    Addrstats[i].pktout,
	    Addrstats[i].octin,
	    Addrstats[i].octout);
	if (write(s, buf, strlen(buf)) < 0)
	    syslog(LOG_ERR, "write: %s\n", strerror(errno));
    }
    write(s, "EOF\n", 4);
    close(s);
    _exit(0);
}

/*
 * Create the non-blocking TCP listener that check_connect() polls.
 *
 * The socket is bound to 127.0.0.1:ADDRCOUNT_PORT only, so the counters
 * are reachable from the local host alone.  Returns the listening
 * descriptor, or -1 (after logging the reason) on any failure.
 */
static int
open_server(void)
{
    struct sockaddr_in addr;
    int s;
    int on = 1;
    int flags;

    s = socket(PF_INET, SOCK_STREAM, 0);
    if (s < 0) {
	syslog(LOG_ERR, "socket: %s\n", strerror(errno));
	return -1;
    }
    if (setsockopt(s, SOL_SOCKET, SO_REUSEADDR, (char *) &on, sizeof(on)) < 0)
	syslog(LOG_WARNING, "setsockopt SO_REUSEADDR: %s\n", strerror(errno));

    memset(&addr, '\0', sizeof(addr));
    addr.sin_family = AF_INET;
    addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
    addr.sin_port = htons(ADDRCOUNT_PORT);
    /*
     * A bind failure is fatal: calling listen() on an unbound socket would
     * auto-bind it to an ephemeral port on every interface, exposing the
     * counters off-host on an unpredictable port.
     */
    if (bind(s, (struct sockaddr *) &addr, sizeof(addr)) < 0) {
	syslog(LOG_ERR, "bind: %s\n", strerror(errno));
	close(s);
	return -1;
    }

    /* Non-blocking so the accept() in check_connect() never stalls capture. */
    if ((flags = fcntl(s, F_GETFL)) < 0) {
	syslog(LOG_ERR, "fcntl F_GETFL: %s\n", strerror(errno));
	close(s);
	return -1;
    }
    if (fcntl(s, F_SETFL, flags | O_NONBLOCK) < 0) {
	syslog(LOG_ERR, "fcntl F_SETFL: %s\n", strerror(errno));
	close(s);
	return -1;
    }
    if (listen(s, 1) < 0) {
	syslog(LOG_ERR, "listen: %s\n", strerror(errno));
	close(s);
	return -1;
    }
    return s;
}

void
reap(int sig)
{
    int status;
    pid_t p;
    while ((p = waitpid(-1, &status, WNOHANG)) > 0);
    signal(SIGCHLD, reap);
}

int
main(int argc, char *argv[])
{
    struct bpf_program bpf;
    pcap_t *pc;
    char errmsg[PCAP_ERRBUF_SIZE];
    char *dev;
    unsigned int netmask;
    unsigned int localnet;
    pid_t pid;
    if (4 != argc) {
	fprintf(stderr, "usage: %s device first_ip_addr naddrs\n", argv[0]);
	return 1;
    }
    dev = strdup(argv[1]);
    ip_min.s_addr = ntohl(inet_addr(argv[2]));
    naddrs = atoi(argv[3]);
    assert(naddrs <= MAXNADDRS);
    openlog("addrcount", 0, LOG_DAEMON);
    if ((pid = fork()) < 0)
	syslog(LOG_ERR, "fork failed: %s\n", strerror(errno));
    else if (pid > 0)
	exit(0);
    if (setsid() < 0)
	syslog(LOG_ERR, "setsid failed: %s\n", strerror(errno));
    signal(SIGCHLD, reap);
    Listen = open_server();
    pc = pcap_open_live(dev, SNAPLEN, 0, 1000, errmsg);
    if (NULL == pc) {
	syslog(LOG_ERR, "%s: %s\n", dev, errmsg);
	return 1;
    }
    if (pcap_lookupnet(dev, &localnet, &netmask, errmsg) < 0) {
	localnet = 0;
	netmask = 0;
	syslog(LOG_WARNING, "%s", errmsg);
    }
    setuid(getuid());
    pcap_compile(pc, &bpf, 0, 1, netmask);
    if (pcap_setfilter(pc, &bpf) < 0) {
	syslog(LOG_ERR, "%s", pcap_geterr(pc));
	return 1;
    }
    syslog(LOG_CRIT, "addrcount starting\n");
    while (1) {
	if (pcap_dispatch(pc, 0, packet_handle, NULL) < 0) {
	    syslog(LOG_ERR, "pcap_dispatch: %s\n", pcap_geterr(pc));
	    break;
	}
	check_connect();
    }
    pcap_close(pc);
    syslog(LOG_CRIT, "addrcount exiting\n");
    return 0;
}
