inherit "module";
#include <module.h>

// How many seconds should be counted into the 'current bps'.
// Upper bound for the averaging window used by do_add_sent().
float BPS_SECONDS=5.0;

// How much is read/written at once.
// Recomputed by set_speed() whenever the speed limit changes.
int BLOCK_SIZE=1024;

// Max transfer (write) speed in bytes per second; a value of 0.0 or less
// means that no limit is applied.
float write_speed=0.0;

// non-throttled ips: globs checked in encapsulate(); connections whose
// address matches one of them are never wrapped.
array(string) nonthrottled=({});

// NOTE(review): not referenced by any code visible in this file.
int do_encapsulate=0;

// Debug hook. Currently a no-op that ignores all of its arguments.
void debug(mixed ... args) {}


// Stdio.File replacement used for throttled connections.  All output goes
// through write(), which does two things:
//  - splits the data into pieces of at most BLOCK_SIZE bytes;
//  - reports every successfully written piece to add_sent().
// add_sent() may sleep to keep the overall rate at or below write_speed.
class FakeFD
{
  inherit Stdio.File;

  // Total number of bytes successfully written through this object.
  int bytes_written;

  // 1 while the file is in nonblocking mode (as last requested through
  // set_nonblocking()/set_blocking()), 0 otherwise.
  int nonblocking;

  // Always reports that there is no usable file descriptor.
  // NOTE(review): presumably so callers cannot bypass the throttled
  // write() by using the descriptor directly -- confirm.
  int query_fd() { return -1; }

  void set_nonblocking(mixed ... args)
  {
    nonblocking=1;
    ::set_nonblocking(@args);
  }

  void set_blocking()
  {
    nonblocking=0;
    ::set_blocking();
  }

  // Number of write() calls so far.  Only the very first call may take
  // the small-write shortcut in write().
  int first_write_flag;

  // Throttled replacement for Stdio.File->write():
  //  o Makes sure that no more than BLOCK_SIZE bytes are written at once
  //  o Reports to add_sent() how much has been written
  //
  // data - a string, or an array of strings written back to back.
  // args - if given, data is first formatted with sprintf(data, @args).
  //
  // Returns the number of bytes written.  If nothing could be written,
  // the result of the underlying write (0 or negative) is returned.
  // In nonblocking mode at most one piece is written per call.
  int write(string|array(string) data, mixed ... args)
  {
    if(sizeof(args))
      data=sprintf(data,@args);
    // From here on data is fully formatted.  Recursive calls below must
    // NOT pass args again, or the text would go through sprintf twice.

    int written;
    int len = arrayp(data) ? `+(0,@Array.map(data, strlen)) : strlen(data);

    /* UGLY hack because roxen assumes that it can write
     * small files without blocking: the first write of at most 4096
     * bytes is handled as if the file were in blocking mode.
     */
    if(!(first_write_flag++)  && nonblocking && len <= 4096)
      {
	nonblocking=0;
	mixed err = catch {
	  written = write(data);	// data is already formatted
	};
	// Restore the mode even if the write threw.
	nonblocking=1;
	if(err) throw(err);
	return written;
      }

    if(len > BLOCK_SIZE)
      {
	// Nonblocking: write a single piece and let the caller come back
	// for the rest.
	if(nonblocking)
	  return write(arrayp(data)?data[0]:data[..BLOCK_SIZE-1]);

	// Blocking: write piece by piece.  String data is cut into
	// BLOCK_SIZE chunks, and the last chunk may be shorter.  Each
	// recursive call reports its own bytes to add_sent().  Stop at the
	// first short write.
	foreach(arrayp(data)?data:data/(float)BLOCK_SIZE,string s)
	  {
	    int tmp=write(s);
	    if(tmp>0) written+=tmp;
	    if(tmp<strlen(s)) return written?written:tmp;
	  }
	return written;
      }else{
	written=::write(data);
      }

    if(written>0)
      {
	add_sent(written);
	bytes_written+=written;
      }
    return written;
  }
}

// Installs s (bytes per second) as the current write limit and derives
// BLOCK_SIZE from it:
//  - with a limit set: a tenth of the limit, but never less than 1 byte;
//  - with s zero or negative (no limit): 65536 bytes.
void set_speed(float s)
{
  write_speed=s;
  BLOCK_SIZE = (write_speed<=0.0) ? 65536 : ((((int)write_speed)/10) || 1);
}

// Configured limits in bytes per second (set by check_variable()).
float regular_bps;
float game_bps;
// time() after which update_bps() re-evaluates the limit.
int next_bps_update;

// Chooses the speed limit to apply and passes it to set_speed().
// Nothing happens until next_bps_update has passed.
//
// With a gaming limit configured, /proc/net/ip_masq/udp is scanned.  If any
// entry's third column carries a port outside 0..1024, the gaming limit is
// used; otherwise the regular one is.  The scan is repeated after 15
// seconds.  Without a gaming limit, the regular limit is used and the next
// check is pushed a year ahead.
void update_bps()
{
  if(next_bps_update >= time())
    return;

  float target=regular_bps;

  if(game_bps > 0.0)
  {
    string table=Stdio.read_file("/proc/net/ip_masq/udp");
    if(table)
    {
      // Skip the header line; stop at the first row with too few columns.
      foreach((table/"\n")[1..], string row)
      {
        array cols=(replace(row,"\t"," ")/" ")-({""});
        if(sizeof(cols) < 6)
          break;

        int a, b, c, d, port;
        if(sscanf(cols[2],"%2x%2x%2x%2x:%4x",a,b,c,d,port)==5 &&
           !(port >= 0 && port <= 1024))
        {
          target=game_bps;
          break;
        }
      }
    }
    next_bps_update=time()+15;
  }
  else
  {
    next_bps_update=time()+365*24*60*60;
  }

  set_speed(target);
}

// Re-reads the module configuration.  Both speed limits are multiplied by
// 1024 to turn the configured KiB/s into bytes per second.  Afterwards the
// limit is re-evaluated immediately.
void check_variable()
{
  regular_bps=query("hubbes_bandwidth_throttle")*1024;
  game_bps=query("hubbes_bandwidth_throttle_when_gaming")*1024;
  nonthrottled=query("nonthrottled_ips");

  // 0 is always in the past, so update_bps() will not skip this call.
  next_bps_update=0;
  update_bps();
}

int bytes=0; // bytes sent so far
float bps=1.0; // Acceptable guess
int base_time=time();
float last_bps_time=-1.0;   // -1.0 until the first call to do_add_sent()
float first_bps_time=-1.0;  // -1.0 until the first call to do_add_sent()
float last_block_size_change=0.0;

// Accounts for b bytes (or strlen(b) if b is a string) having been sent.
// If a limit is active and the measured rate exceeds write_speed, it
// sleeps until the rate has decayed back to write_speed.
//
// bps is an exponentially decaying average.  On every call it is
// multiplied by bps_frac^(seconds since the last call), and then
// b/bpssec is added, where bps_frac = 1 - 1/bpssec.  The window bpssec
// grows linearly from 1 to BPS_SECONDS+1 over the first 3*BPS_SECONDS
// seconds after the first call.
//
// The very first call only starts the clock; its bytes are not counted.
void do_add_sent(int|string b)
{
  // Explicit float test: the sentinel is -1.0, and time(base_time) is never
  // negative, so this holds only before the first call.
  if(first_bps_time < 0.0)
  {
    last_block_size_change=first_bps_time=last_bps_time=time(base_time);
    return;
  }
  update_bps();
  if(stringp(b)) b=strlen(b);
  float t=time(base_time);
  float bpssec=max(0.0,min(BPS_SECONDS, (t-first_bps_time)/3.0))+1.0;

  float bps_frac=1.0-1.0/bpssec;
  bytes+=b;
  debug(3,"\nBPS=%f %f %f %f %d",bps,bps_frac,t-last_bps_time,t,b);
  bps = bps * pow(bps_frac,(float)(t-last_bps_time)) + b/bpssec;
  debug(3," BPS=%f\n",bps);
  last_bps_time=t;

  if(write_speed >0.0 && bps > write_speed)
  {
    // Solve bps * bps_frac^x == write_speed for x.  This is how long the
    // average needs to decay down to the permitted speed.
    float wakeup_time=log(write_speed/bps) / log(bps_frac);
    debug(3,"\nSleeping %f seconds.\n",wakeup_time);
    sleep(wakeup_time);

    // Apply the decay that happened while sleeping.
    t=time(base_time);
    bpssec=max(0.0,min(BPS_SECONDS, (t-first_bps_time)/3.0))+1.0;
    bps = bps * pow(1.0-1.0/bpssec,(float)(t-last_bps_time));
    last_bps_time=t;
  }
}

// Cached reference to the shared accounting function.
function global_add_sent;

// Reports that bytes bytes have been sent.  Accounting is shared via the
// global constant "bandwidth_throttle_add_sent":
//  - if the constant already exists, it is reused;
//  - otherwise this program's do_add_sent() is registered under that name.
void add_sent(int bytes)
{
  if(!global_add_sent)
  {
    global_add_sent=all_constants()->bandwidth_throttle_add_sent;
    if(!global_add_sent)
    {
      global_add_sent=do_add_sent;
      add_constant("bandwidth_throttle_add_sent",global_add_sent);
    }
  }
  global_add_sent(bytes);
}

// Wraps a newly accepted connection in a throttling FakeFD.
// The connection is returned untouched when either:
//  - its address matches one of the globs in nonthrottled, or
//  - no speed limit is set.
// Otherwise the original object is destructed after its file has been
// assigned to the wrapper, so callers must use the returned object.
object encapsulate(object tmp)
{
  // Stdio.File->query_address() returns "ip port" (or 0 on failure).
  // Match the globs against the ip part only, so exact addresses such as
  // "10.0.1.5" work as well as patterns like "10.0.1.*".  A missing address
  // is treated as "not excluded" instead of making glob() throw.
  string addr=tmp->query_address();
  string ip=addr && (addr/" ")[0];

  if(ip)
    foreach(nonthrottled,string s)
      if(glob(s, ip))
	return tmp;

  // Charges a fixed 20 bytes per throttled-candidate connection.
  // NOTE(review): presumably per-connection overhead -- confirm.
  add_sent(20);
  if(write_speed <= 0.0) return tmp;
  object o=FakeFD();
  o->assign(tmp);
  destruct(tmp);
  return o;
}


// Wrapper installed in place of a port's connection handler.  create()
// replaces the last element of the port entry array with this object.
// Each new connection therefore reaches `() below, passes through
// encapsulate(), and is then handed to the original handler.  Several
// FakeProto objects may be stacked on one entry, linked through their
// realproto members.
class FakeProto
{
  // The port entry array this wrapper was installed into.
  array proto;
  // Whatever occupied proto[-1] before this wrapper replaced it: the
  // original handler or another FakeProto.
  mixed realproto;
  void create(array a)
  {
    proto=a;
    realproto=a[-1];
    a[-1]=this_object();
  }

  // Removes this wrapper from the chain by putting realproto back where
  // this object was referenced.  Returns 1 if this object was found and
  // unlinked, 0 otherwise.  It may be found either directly in proto[-1]
  // or as some other wrapper's realproto.
  int unlink()
  {
    if(proto[-1] == this_object())
      {
	proto[-1]=realproto;
	return 1;
      }
    // Walk the chain of wrappers.  Any error raised while doing so is
    // discarded, and the search then ends with a return value of 0.
    catch {
      for(object o=proto[-1];o;o=o->realproto)
	{
	  if(o->realproto == this_object())
	    {
	      o->realproto=realproto;
	      return 1;
	    }
	}
    };
    return 0;
  }

  // Called instead of the original handler with a new connection.
  // If encapsulate() throws, the error is discarded and this wrapper
  // unlinks and destructs itself.  The connection then goes to whatever
  // handler is now in proto[-1].  If unlinking fails, or encapsulate()
  // succeeds, the connection is passed on to realproto.
  object `()(object fd, object conf)
  {
    if( catch {
      fd=encapsulate(fd);
    }) {
      if(unlink())
	{
	  // Copy proto into a local first; presumably object members are
	  // not usable once this_object() has been destructed.
	  mixed tmp=proto;
	  destruct(this_object());
	  // NOTE(review): passes tmp[1] (second element of the port entry)
	  // rather than conf; presumably the same configuration -- confirm.
	  return tmp[-1](fd,tmp[1]);
	}
    }
    return realproto(fd,conf);
  }
}

// Declares the module's configuration variables.  The two speed limits
// are entered in KiB/s; check_variable() multiplies them by 1024 to
// obtain bytes per second.
void create()
{
  defvar("hubbes_bandwidth_throttle",
	 0.0,
	 "outgoing bandwidth",
	 TYPE_FLOAT | VAR_MORE,
	 "This variable allows you to limit the outgoing bandwidth of "
	 "your roxen server in KiB/s. A value of zero means no limitation.");

  defvar("hubbes_bandwidth_throttle_when_gaming",
	 0.0,
	 "when gaming",
	 TYPE_FLOAT | VAR_MORE,
	 "If you are using Linux, masquerading and running your roxen "
	 "on the same machine that does the masquerading, then you can "
	 "use this variable to limit your bandwidth only when somebody "
	 "is playing games. This is detected by checking "
	 "/proc/net/ip_masq/udp for entries using ports above 1024. "
	 "A value of zero means that the regular bandwidth limit will be "
	 "applied instead.");

  defvar("nonthrottled_ips",
	 ({}),
	 "Excluded ips",
	 TYPE_STRING_LIST,
	 "A comma separated list of IP address globs that will not be "
	 "throttled. Example: 10.0.1.*, 10.0.2.*");
}

// Set once the port handlers have been wrapped, so later calls to start()
// do not wrap them a second time.
int started;

// Module start hook.  On the first call it installs a FakeProto in front of
// the handler of every port entry whose second element is set.  On every
// call it re-reads the configuration.
void start()
{
  if(!started)
  {
    mapping ports=roxen->portno;
    foreach(indices(ports), object p)
    {
      array entry=ports[p];
      if(entry[1])
	FakeProto(entry);
    }
    started=1;
  }
  check_variable();
}

// Destructs the cached configuration interface object, if there is one,
// and rebuilds the configuration interface tree from roxen->root.
// NOTE(review): not called within this file; presumably invoked by the
// server -- confirm.
void update_global_variables()
{
  object old_ci=roxen->configuration_interface_obj;
  if(old_ci)
    destruct(old_ci);
  roxen->configuration_interface()->build_root(roxen->root);
}

// Returns the module registration array.  The two strings are the module
// name and its description; the meaning of the numeric fields is defined
// by the Roxen module API.
array register_module()
{
  return ({ 0, "Hubbes bandwidth throttle module", 
	    "This module allows you to limit the bandwidth of your roxen. "
	    "It also has a special mode that allows you to limit the "
	    "bandwidth only when you are playing Quake/Half-Life/Unreal. "
	    "(This mode only works if you are running Linux and using "
	    "NAT and running roxen on the host that does the "
	    "NAT.) BEWARE: This module affects *all* virtual servers, "
	    "not just the server it was added to!", 0, 1 });
}
