package com.navtools.armi.networking;

import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketException;
import java.net.ConnectException;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.util.*;

import com.navtools.util.ByteQueue;
import com.navtools.util.MathUtil;
import com.navtools.util.PerformanceMonitor;
import org.apache.log4j.Category;

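/**
 *  Channel-based (java.nio) communicator. It either listens for incoming
 *  connections on a local port, or connects as a client to a server socket.
 *  Each message is sent as a length prefix followed by its payload bytes.
 */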

public class ChannelCommunicator
{

	private int serverPort_ = 0;
	private int listenPort_ = 10001;
	// NOTE(review): serverHost_ and outBuffer are never read in this class.
	private String serverHost_ = null;
	private ChannelDataServerListener dataListener_ = null;
	/** Scratch buffer for socket reads; only readMessage (selector thread) uses it. */
	private ByteBuffer inBuffer = ByteBuffer.allocate( 4096 );
	private ByteBuffer outBuffer = null;
	private Selector sel_ = null;
	private InetAddress serverAddress_ = null;
	protected SelectionKey serverKey_ = null;
	protected SelectionKey acceptKey_ = null;
	/** SelectionKey -> ByteQueue of framed outbound bytes not yet written. */
	protected Hashtable outBufferHash = new Hashtable();
	/** SelectionKey -> ByteQueue of received bytes not yet parsed into messages. */
	protected Hashtable inQueueHash = new Hashtable();
	/** SelectionKey -> Integer size of the message being received (0 = waiting for a size prefix). */
	protected Hashtable messageSizeHash = new Hashtable();
	/** ServerAddress -> SelectionKey of the client connection to that server. */
	protected Hashtable addressToKeyHash = new Hashtable();
	/** SelectionKey -> List of DisconnectionListener notified when that connection drops. */
	protected Hashtable disconnectListeners_ = new Hashtable();

	// Counters that only appear in debug log messages (updated without synchronization).
	private static int sendNumber = 0;
	private static int getNumber = 0;

	private boolean isClient = false;

	public static final Category LOG =
	Category.getInstance( ChannelCommunicator.class.getName() );

	/**
	 * This hashtable will be used to store an Integer
	 * representing the channelBuffer number that
	 * will be written to by the channel with the corresponding
	 * SelectionKey.
	 * NOTE(review): not referenced anywhere in this class.
	 */
	private Hashtable channelKeyHash_ = new Hashtable();
	/** Number of finishConnect() polls made while opening the primary client connection. */
	private int tryCounter = 0;


	/**
	 * Creates a communicator that listens for incoming connections.
	 *
	 * @param listenPort local port to listen on; values <= 0 keep the default 10001
	 * @param dataReader receives every complete incoming message
	 * @throws Exception if the selector cannot be opened, or a SocketException
	 *         (e.g. BindException) if the listening socket cannot be bound
	 */
	public ChannelCommunicator( int listenPort, ChannelDataServerListener dataReader ) throws Exception
	{
		dataListener_ = dataReader;
		if ( listenPort > 0 )
		{
			listenPort_ = listenPort;
		}
		initializeChannel();
	}


	/**
	 * Creates a communicator that connects as a client to serverAddress:serverPort.
	 * If serverAddress is null it behaves like the listening constructor instead.
	 *
	 * @param serverAddress server to connect to, or null to listen
	 * @param serverPort    server port to connect to
	 * @param listenPort    stored as the listen port when > 0 (only used when listening)
	 * @param dataReader    receives every complete incoming message
	 * @throws Exception if the selector cannot be opened, or a SocketException
	 *         (e.g. ConnectException) if the connection cannot be completed
	 */
	public ChannelCommunicator( InetAddress serverAddress, int serverPort, int listenPort,
	                            ChannelDataServerListener dataReader ) throws Exception
	{
		dataListener_ = dataReader;
		serverAddress_ = serverAddress;
		serverPort_ = serverPort;
		if ( listenPort > 0 )
		{
			listenPort_ = listenPort;
		}
		initializeChannel();
	}


	/**
	 * Opens the selector and either a listening socket or the primary client
	 * connection, then starts the selector thread.
	 *
	 * Any SocketException, including subclasses such as BindException and
	 * ConnectException, closes the selector and is rethrown. Other failures are
	 * logged and the selector thread is started anyway (original behaviour).
	 */
	private void initializeChannel() throws Exception
	{
		// Open a selector
		sel_ = Selector.open();

		try
		{
			if ( serverAddress_ == null )
			{
				openListeningChannel();
			}
			else
			{
				openClientChannel();
			}
		}
		catch ( Exception e )
		{
			// instanceof (not getClass()==) so BindException, NoRouteToHostException,
			// ConnectException etc. are reported to the caller instead of swallowed.
			if ( e instanceof SocketException )
			{
				try
				{
					sel_.close();
				}
				catch ( java.io.IOException ignored )
				{
					// Already failing; the original exception is the one that matters.
				}
				throw e;
			}
			System.out.println( "Exception thrown trying to connect: " + e.getClass().getName() );
			LOG.error( "Exception thrown when client connection", e );
		}

		startSelectorThread();
	}


	/**
	 * Binds a non-blocking server socket on the local host at listenPort_ and
	 * registers it for OP_ACCEPT. The channel is closed if binding fails.
	 */
	private void openListeningChannel() throws java.io.IOException
	{
		/* Create a server socket channel for accepting any connections */
		ServerSocketChannel server = ServerSocketChannel.open();
		try
		{
			server.configureBlocking( false );
			InetAddress ia = InetAddress.getLocalHost();
			InetSocketAddress isa = new InetSocketAddress( ia, listenPort_ );
			LOG.info( "Opening connection to listen on port " + listenPort_ );
			System.out.println( "Opening listening connection" );
			server.socket().bind( isa );
			acceptKey_ = server.register( sel_, SelectionKey.OP_ACCEPT );
		}
		catch ( SocketException e )
		{
			closeQuietly( server );
			throw e;
		}
	}


	/**
	 * Opens the primary client connection to serverAddress_:serverPort_.
	 *
	 * Completion is polled up to 5 times, 20 ms apart. If it still is not
	 * complete, a SocketException is thrown and the channel is closed.
	 */
	private void openClientChannel() throws Exception
	{
		/* if a serverhost has been specified, open client connection to it */
		isClient = true;
		SocketChannel channel = SocketChannel.open();
		try
		{
			channel.configureBlocking( false );
			InetSocketAddress socketAddress = resolveSocketAddress( serverAddress_, serverPort_ );

			LOG.info( "Connection to server at socket address: " + socketAddress );

			channel.connect( socketAddress );
			serverKey_ = channel.register( sel_, SelectionKey.OP_CONNECT | SelectionKey.OP_READ );
			addressToKeyHash.put( ServerAddress.make( serverAddress_, serverPort_ ), serverKey_ );
			LOG.info( "Going to finish connecting to: " + socketAddress );
			System.out.print( "Going to complete connection.." );
			while ( ( channel.finishConnect() == false ) && tryCounter < 5 )
			{
				LOG.info( "Have not finished connecting yet..." );
				System.out.print( "." );
				Thread.sleep( 20 );
				tryCounter++;
			}

			if ( channel.finishConnect() == false )
			{
				System.out.println( ".failed" );
				throw new SocketException( "Failure to connect to server." );
			}
			System.out.println( ".done" );

			// A connected channel that still has OP_CONNECT interest is reported
			// ready by the selector over and over, which makes the selector thread
			// spin. This is safe to change here because that thread has not started yet.
			serverKey_.interestOps( SelectionKey.OP_READ );

			LOG.info( "Finished making connection to socket address: " + socketAddress );
		}
		catch ( SocketException e )
		{
			closeQuietly( channel );
			throw e;
		}
	}


	/**
	 * Builds the socket address for ip:port. If ip equals the address that
	 * "localhost" resolves to, InetAddress.getLocalHost() is used instead.
	 */
	private static InetSocketAddress resolveSocketAddress( InetAddress ip, int port ) throws java.io.IOException
	{
		if ( ip.equals( InetAddress.getByName( "localhost" ) ) )
		{
			return new InetSocketAddress( InetAddress.getLocalHost(), port );
		}
		return new InetSocketAddress( ip, port );
	}


	/**
	 * Starts the thread that owns the selector. Each pass does three things:
	 * flush queued outbound data, select (without blocking if data is still
	 * pending, otherwise blocking), then handle the ready keys.
	 */
	private void startSelectorThread()
	{
		new Thread( toString() )
		{
			public void run()
			{
				while ( true )
				{
					try
					{
						boolean leftoverData = flushOutboundQueues();

						// Block only when nothing is waiting to be written; write()
						// calls wakeup() when new data is queued.
						if ( leftoverData )
						{
							sel_.selectNow();
						}
						else
						{
							sel_.select();
						}

						processSelectedKeys();

						Thread.sleep( 2 );
					}
					catch ( Exception e )
					{
						LOG.error( "Exception thrown during selection process", e );
					}
				}
			}
		}.start();
	}


	/**
	 * Tries to write every non-empty outbound queue to its channel.
	 *
	 * @return true if any queue still holds unwritten bytes afterwards
	 */
	private boolean flushOutboundQueues()
	{
		boolean leftoverData = false;
		Enumeration keys = outBufferHash.keys();
		while ( keys.hasMoreElements() )
		{
			SelectionKey currKey = (SelectionKey) keys.nextElement();
			ByteQueue outQueue = (ByteQueue) outBufferHash.get( currKey );
			if ( outQueue != null && outQueue.remaining() > 0 )
			{
				try
				{
					writeMessage( currKey );
					if ( outQueue.remaining() > 0 )
					{
						leftoverData = true;
					}
				}
				catch ( Exception e )
				{
					LOG.error( "Exception thrown trying to write to channel: " + e.getMessage() );
				}
			}
		}
		return leftoverData;
	}


	/** Handles accept, connect-completion and read readiness for every selected key. */
	private void processSelectedKeys() throws java.io.IOException
	{
		Iterator it = sel_.selectedKeys().iterator();
		while ( it.hasNext() )
		{
			SelectionKey key = (SelectionKey) it.next();
			// Remove first so a failure below cannot leave the key stuck in the selected set.
			it.remove();

			if ( key.isValid() && key.isAcceptable() )
			{
				acceptConnection( key );
			}
			if ( key.isValid() && key.isConnectable() )
			{
				finishPendingConnect( key );
			}
			if ( key.isValid() && key.isReadable() )
			{
				readMessage( key );
			}
		}
	}


	/** Accepts a pending connection and registers it non-blocking for OP_READ. */
	private void acceptConnection( SelectionKey key ) throws java.io.IOException
	{
		ServerSocketChannel ssc = (ServerSocketChannel) key.channel();
		SocketChannel newChannel = ssc.accept();
		// In non-blocking mode accept() returns null when no connection is pending.
		if ( newChannel == null )
		{
			return;
		}
		newChannel.configureBlocking( false );
		newChannel.register( sel_, SelectionKey.OP_READ );
		LOG.debug( "Accepted a new client connection." );
	}


	/**
	 * Completes a connect started by connectToServer. On success the key's
	 * interest is narrowed to OP_READ. On failure the connection is closed and
	 * its disconnect listeners are notified.
	 */
	private void finishPendingConnect( SelectionKey key )
	{
		SocketChannel sc = (SocketChannel) key.channel();
		try
		{
			if ( sc.finishConnect() )
			{
				key.interestOps( SelectionKey.OP_READ );
			}
		}
		catch ( java.io.IOException e )
		{
			LOG.error( "Failed to complete connection to server: ", e );
			closeKey( key );
			notifyDisconnect( key, e );
		}
	}


	/**
	 * Reads available bytes from the key's channel and appends them to its in
	 * queue. Then, for each complete message available, calls the data listener.
	 *
	 * Framing: a 4-byte size (decoded with MathUtil.asInt) followed by that many
	 * payload bytes. A partially received size is carried across calls in
	 * messageSizeHash. End-of-stream or an IOException closes the connection
	 * and notifies its disconnect listeners.
	 */
	private void readMessage( SelectionKey key )
	{

		if ( key == null )
		{
			return;
		}
		try
		{
			SocketChannel sc = (SocketChannel) key.channel();

			Integer msgSize = (Integer) messageSizeHash.get( key );
			if ( msgSize == null )
			{
				msgSize = new Integer( 0 );
			}

			ByteQueue inQueue = (ByteQueue) inQueueHash.get( key );
			if ( inQueue == null )
			{
				inQueue = new ByteQueue();
				inQueueHash.put( key, inQueue );
			}

			if ( sc == null || !sc.isConnected() )
			{
				return;
			}
			synchronized ( sc )
			{
				inBuffer.clear();
				int numBytes = sc.read( inBuffer );
				if ( numBytes < 0 )
				{
					// End of stream: the peer closed the connection. Without this the
					// key stays readable and is re-selected forever.
					throw new java.io.EOFException( "Connection closed by peer" );
				}
				if ( numBytes > 0 )
				{

					if ( PerformanceMonitor.instance().isThroughputEnabled() )
					{
						PerformanceMonitor.instance().addIncomingBytes( new Long( key.hashCode() + "" ),
						                                                System.currentTimeMillis(),
						                                                numBytes );
					}
					inBuffer.flip();
					byte[] newBytes = new byte[numBytes];

					inBuffer.get( newBytes );
					LOG.debug( "ACTUALLY read " + numBytes + " from the channel" );
					LOG.debug( "Number of bytes just pulled from the raw socket: " + newBytes.length );
					LOG.debug( "Number of bytes sitting in the in queue BEFORE  pushing new raw bytes: " + inQueue.remaining() );
					inQueue.push( newBytes );
					LOG.debug( "Number of bytes sitting in the in queue AFTER pushing new raw bytes: " + inQueue.remaining() );


					boolean keepParsing = true;
					while ( keepParsing )
					{
						//check if message size is 0, and if so, get next size
						if ( msgSize.intValue() == 0 )
						{
							if ( inQueue.remaining() > 3 )
							{
								byte[] sizeBytes = new byte[4];
								inQueue.pop( sizeBytes, 0, 4 );
								msgSize = new Integer( MathUtil.asInt( sizeBytes, 0 ) );
								LOG.debug( "(" + ++getNumber + ") message size will be: " + msgSize );
							}
							else
							{
								//we need a new message size, but there aren't enough
								//bytes in the queue to get one - break out till there
								//are more
								break;
							}
						}

						LOG.debug( "Current inQueue size is: " + inQueue.remaining() );
						LOG.debug( "msgSize is: " + msgSize.intValue() );

						// if all the bytes for the next message are in the queue
						// process the message and reset the msgSize
						if ( inQueue.remaining() >= msgSize.intValue() )
						{
							byte[] messageBytes = new byte[msgSize.intValue()];
							inQueue.pop( messageBytes, 0, msgSize.intValue() );
							LOG.debug( "Calling listener for message number: " + getNumber );
							dataListener_.newData( key, messageBytes );
							msgSize = new Integer( 0 );
							messageSizeHash.put( key, msgSize );
						}
						else
						{
							messageSizeHash.put( key, msgSize );
							break;
						}
					}

				}
			}
		}
		catch ( java.io.IOException e )
		{
			LOG.debug( "This channel is no longer valid" );
			closeKey( key );
			notifyDisconnect( key, e );
			System.out.println( "IOException in readMessage" );
		}
		catch ( Throwable t )
		{
			LOG.error( "Unexpected error while reading from channel", t );
		}
	}


	/**
	 * Returns the connection key for address. A new client connection is
	 * opened when none is recorded, or when the recorded key is no longer valid
	 * (e.g. after a disconnect).
	 *
	 * @return the key, or null if a new connection could not be started
	 */
	public SelectionKey getKeyForAddress( ServerAddress address )
	{
		SelectionKey aKey = (SelectionKey) addressToKeyHash.get( address );
		if ( aKey == null || !aKey.isValid() )
		{
			aKey = connectToServer( address );
		}
		return aKey;
	}


	/**
	 * Registers listener to be told when the connection for listenConn drops.
	 * A null key is ignored, because there is no connection to watch.
	 */
	public void addDisconnectionListener( DisconnectionListener listener, SelectionKey listenConn )
	{
		if ( listenConn == null )
		{
			// Hashtable rejects null keys; this would otherwise throw a NullPointerException.
			LOG.debug( "DisconnectionListener ignored: no connection key" );
			return;
		}
		LOG.debug( "DisconnectionListener being added" );
		synchronized ( disconnectListeners_ )
		{
			List listeners = (List) disconnectListeners_.get( listenConn );
			if ( listeners == null )
			{
				listeners = new ArrayList();
				disconnectListeners_.put( listenConn, listeners );
			}
			listeners.add( listener );
		}
	}


	/**
	 * Starts a non-blocking client connection to address and records its key.
	 * If the connect does not complete immediately, the selector thread
	 * finishes it later (see finishPendingConnect).
	 *
	 * @return the new connection's key, or null if it could not be started
	 */
	private SelectionKey connectToServer( ServerAddress address )
	{
		/* if a serverhost has been specified, open client connection to it */
		SocketChannel channel = null;
		SelectionKey aServerKey = null;
		try
		{
			channel = SocketChannel.open();
			channel.configureBlocking( false );
			// NOTE(review): this uses serverPort_, not a port carried by address — confirm intended.
			InetSocketAddress socketAddress = resolveSocketAddress( address.getIP(), serverPort_ );

			boolean connected = channel.connect( socketAddress );
			// Only ask for OP_CONNECT while the connect is still pending.
			aServerKey = channel.register( sel_, connected
			                                     ? SelectionKey.OP_READ
			                                     : ( SelectionKey.OP_CONNECT | SelectionKey.OP_READ ) );
			addressToKeyHash.put( address, aServerKey );
			if ( !connected && channel.finishConnect() )
			{
				aServerKey.interestOps( SelectionKey.OP_READ );
			}
			return aServerKey;
		}
		catch ( Exception e )
		{
			LOG.error( "problem trying to connect to additional server: ", e );
			if ( aServerKey != null )
			{
				addressToKeyHash.remove( address );
			}
			if ( channel != null )
			{
				// Closing the channel also cancels any key registered for it.
				closeQuietly( channel );
			}
			return null;
		}
	}


	/**
	 * Writes as much of the key's outbound queue as its channel will accept
	 * right now. Nothing is written while the connection is still pending.
	 * On an IOException the connection is closed and its disconnect listeners
	 * are notified.
	 */
	private void writeMessage( SelectionKey key )
	{
		if ( key == null )
		{
			return;
		}

		ByteQueue outQueue = null;

		try
		{
			SocketChannel sc = (SocketChannel) key.channel();
			if ( !sc.isConnected() )
			{
				// Still connecting; leave the data queued for a later pass.
				return;
			}
			outQueue = (ByteQueue) outBufferHash.get( key );

			if ( outQueue != null )
			{
				// Same lock write() holds while pushing, so push and pop never interleave.
				synchronized ( outQueue )
				{
					if ( outQueue.remaining() > 0 )
					{
						LOG.debug( "Going to write to the channel.  The outBuffer has this many bytes remaining " + outQueue.remaining() );
						// Call method on ByteQueue that will keep writing to channel until buffer
						// is empty or channel won't take anymore.
						int bytesWrote = outQueue.pop( sc );
						LOG.debug( "ACTUALLY wrote " + bytesWrote + " to the channel" );
					}
				}
			}
		}
		catch ( java.io.IOException e )
		{
			LOG.debug( "This channel is no longer valid" );
			closeKey( key );
			if ( outQueue != null )
			{
				synchronized ( outQueue )
				{
					outQueue.pop();
				}
			}
			notifyDisconnect( key, e );
			System.out.println( "IOMessage in writeMessage" );
		}
		catch ( Throwable t )
		{
			LOG.error( "Unexpected error while writing to channel", t );
		}
	}


	/**
	 * Drops the per-connection queues for key and calls its disconnection
	 * listeners with reason. The listener list is detached first: the key is
	 * dead, so those listeners can never fire again.
	 */
	private void notifyDisconnect( SelectionKey key, Exception reason )
	{
		LOG.debug( "Notifying of disconnect by calling disconnected listeners" );
		outBufferHash.remove( key );
		inQueueHash.remove( key );
		messageSizeHash.remove( key );

		List listeners;
		synchronized ( disconnectListeners_ )
		{
			listeners = (List) disconnectListeners_.remove( key );
		}

		// Call out to listeners without holding the lock.
		if ( listeners != null )
		{
			Iterator iter = listeners.iterator();
			while ( iter.hasNext() )
			{
				( (DisconnectionListener) iter.next() ).disconnected( reason );
			}
		}
	}


	/** Cancels key and closes its channel; cancel() alone leaves the socket open. */
	private static void closeKey( SelectionKey key )
	{
		key.cancel();
		closeQuietly( key.channel() );
	}


	/** Closes channel, ignoring errors; callers are already handling a failure. */
	private static void closeQuietly( java.nio.channels.Channel channel )
	{
		try
		{
			channel.close();
		}
		catch ( java.io.IOException ignored )
		{
			// Best effort: nothing useful can be done if close itself fails.
		}
	}


	/**
	 * Queues data for key, prefixed with its length, and wakes the selector
	 * thread so the data is written. Null or invalid keys are ignored.
	 */
	public void write( SelectionKey key, byte[] data )
	{
		if ( key == null || key.isValid() == false )
		{
			return;
		}

		if ( PerformanceMonitor.instance().isThroughputEnabled() )
		{
			PerformanceMonitor.instance().addOutgoingBytes( new Long( key.hashCode() + "" ),
			                                                System.currentTimeMillis(),
			                                                data.length );
		}

		// Get-or-create atomically so concurrent writers cannot each install a
		// queue and lose one writer's data.
		ByteQueue outQueue;
		synchronized ( outBufferHash )
		{
			outQueue = (ByteQueue) outBufferHash.get( key );
			if ( outQueue == null )
			{
				outQueue = new ByteQueue();
				outBufferHash.put( key, outQueue );
			}
		}

		byte[] sizeArray = MathUtil.asBytes( data.length );

		synchronized ( outQueue )
		{
			LOG.debug( ++sendNumber + " pushing a message out of size:" + data.length );
			outQueue.push( sizeArray );
			outQueue.push( data );
			LOG.debug(
			"After pushing outbound data to the outQueue, the outQueue has this many bytes in it: " + outQueue.remaining() );
		}
		// If the select is currently blocked it will need to be woken so it will
		// know that it has something to write
		sel_.wakeup();
		LOG.debug( "Called \"wakeup\" on the selector" );
	}


	/** Returns the selector owned by this communicator's selector thread. */
	protected Selector getSelector()
	{
		return sel_;
	}
}
