/*
 * $Id: ARMIProxyFactory.java,v 1.2 2003/09/24 14:55:50 cactushack76 Exp $
 */

package com.navtools.armi;

/**
 * Asynchronous RMI Proxy factory
 *
 * Use this to build proxies to remote objects. The remote object
 * must implement an interface, said interface is the argument
 * used to build the proxy object.  The only requirements on the interface
 * are that the first argument of every method must be a long msgID, and
 * all other arguments and return values must be either a long, int, short,
 * byte, String, Long, Integer, Short, Byte, or a class that implements
 * DataStreamable or is shared.  This is similar to the restriction on regular RMI
 * objects that all arguments and return values be serializable or Remote.
 * Since this is an asynchronous proxy, the return value will not come back
 * immediately.  You can either wait for it or register a block of code
 * to be executed when the return value comes back.
 */
//import java.rmi.UnknownHostException;

import java.io.IOException;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.net.InetAddress;
import java.net.SocketException;
import java.net.UnknownHostException;
import java.nio.channels.SelectionKey;
import java.util.BitSet;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import com.navtools.armi.networking.*;
import com.navtools.serialization.DataStreamableUtil;
import com.navtools.util.MathUtil;
import com.navtools.util.Pair;
import com.navtools.util.PerformanceMonitor;
import com.navtools.util.Triplet;
import org.apache.log4j.Category;
import org.apache.log4j.NDC;

/**
 * Invocation handler behind the asynchronous RMI proxies, plus static
 * helpers for building those proxies and for encoding/decoding method
 * arguments into {@link Message}s.
 *
 * <p>Each handler instance is bound to one remote instance (a
 * {@link MessengerID} plus an instance id).  Most calls are forwarded as an
 * {@code RMIMessage}, and a {@code RemoteReturnValue} is returned immediately.
 */
public class ARMIProxyFactory
implements InvocationHandler
{
    public static final Category LOG =
    Category.getInstance( ARMIProxyFactory.class.getName() );


    /**
     * Creates a handler bound to one remote instance.  Callers normally go
     * through {@link #forObject}, which caches handlers.
     *
     * @param address    messenger address of the remote side
     * @param instanceID id of the target instance on the remote side
     */
    protected ARMIProxyFactory( MessengerID address, int instanceID )
    {
        address_ = address;
        instanceID_ = instanceID;
        // NOTE(review): t_ is static, so each construction refreshes the
        // table reference shared by all handlers.
        t_ = ClassAndMethodTable.instance();
    }


    // Registers the request and response RMI message types at class load.
    static
    {
        RMIResponseMessage.register();

        RMIMessage.register();
    }


    /**
     * Hooks the RMI request and response listeners into the incoming message
     * queue.  Must run after {@code imq_} has been established.
     */
    protected static void sharedPostInit()
    {
        imq_.addMessageListener( RMIMessageListener.instance() );
        imq_.addMessageListener( RMIResponseMessageListener.instance() );
    }


    /**
     * Call this method once, typically in main, to establish the queue and
     * ClassAndMethodTable to use for message translation and transmission.
     * This is the method to call for programs which are clients to the
     * registry server.
     *
     * @param registryAddress host of the registry server, which is contacted
     *                        on {@link #DEFAULT_PORT}
     */
    public static void initialize( InetAddress registryAddress )
    throws SocketException, UnknownHostException, Exception
    {
        imq_ = StandardUDPIncomingMessageQueue.establishInstance( registryAddress, DEFAULT_PORT );

        ServerAddress registryServerAddress =
        ServerAddress.make( registryAddress, DEFAULT_PORT );

        ClassAndMethodTable.establishInstance( registryServerAddress );

        sharedPostInit();
    }


    /**
     * Call this method once, typically in main, to establish the queue and
     * ClassAndMethodTable to use for message translation and transmission.
     * This is the method to call for the program which also serves as the
     * registry server.
     */
    public static void initializeAsRegistryServer()
    throws SocketException, UnknownHostException, Exception
    {
        imq_ = StandardUDPIncomingMessageQueue.establishInstance( DEFAULT_PORT );
        sharedPostInit();
    }


    /**
     * Handles a call made on a proxy.
     *
     * <p>Some calls are answered locally and never sent:
     * <ul>
     * <li>the {@code DisconnectionListened.DISCONNECT_CALLBACK} method with a
     *     {@link DisconnectionListener} argument registers the listener;</li>
     * <li>{@code hashCode}, {@code equals} and {@code toString} work on the
     *     proxy's (address, instance id) identity;</li>
     * <li>{@code shutdown} closes all comms on the incoming queue.</li>
     * </ul>
     * Every other call is encoded as an {@code RMIMessage} and sent to the
     * remote instance.
     *
     * @return a {@code RemoteReturnValue} whose message id is the id of the
     *         sent message, or the locally computed result for the calls above
     * @throws Exception if the method is not known to the
     *         {@code ClassAndMethodTable} or sending fails.  The error is
     *         logged before being rethrown.
     */
    public Object invoke( Object proxy, Method m, Object[] args )
    throws Throwable
    {
        RemoteReturnValue retVal = new RemoteReturnValue( Message.getNextID() );
        String methodName = m.getName();
        // Pop the NDC only if this call pushed onto it.  Popping
        // unconditionally would discard the caller's context on early
        // returns and on failures that happen before the push.
        boolean ndcPushed = false;
        try
        {
            if ( ChannelMessenger.MSG_PERF.isDebugEnabled() )
            {
                ChannelMessenger.MSG_PERF.debug( "invoked " + methodName + " on " +
                                                 address_ );
            }

            // ---- calls answered locally ----
            if ( methodName.equals( DisconnectionListened.DISCONNECT_CALLBACK )
                 && args != null && args.length > 0
                 && args[0] instanceof DisconnectionListener )
            {
                addDisconnectionListener( proxyKey_, (DisconnectionListener) args[0] );
                return null;
            }
            if ( methodName.equals( "hashCode" ) )
            {
                return new Integer( proxyHashCode( proxy ) );
            }
            if ( methodName.equals( "shutdown" ) )
            {
                imq_.patientlyCloseAllComms();
                return null;
            }
            if ( methodName.equals( "equals" ) )
            {
                return Boolean.valueOf( proxyEquals( proxy, args[0] ) );
            }
            if ( methodName.equals( "toString" )
                 && ( args == null || args.length == 0 ) )
            {
                // Object.toString is not in the method table, so without
                // this branch it would fail with "method doesn't exist".
                return "ARMIProxy[" + address_ + ", instance " + instanceID_ + "]";
            }

            // ---- calls forwarded to the remote instance ----
            long msgBuildStart = System.currentTimeMillis();

            Integer methodID = t_.getID( m );
            if ( methodID == null )
            {
                throw new Exception( "method doesn't exist on proxy" );
            }

            // When you update this code that builds the RMIMessage, update
            // the sister code in RMIMessageListener.processMessage.
            Message msg = new RMIMessage( retVal.getMessageID(),
                                          instanceID_,
                                          methodID.intValue() );
            writeObjsToMessage( msg, args, m.getParameterTypes() );

            if ( ChannelMessenger.MSG_PERF.isDebugEnabled() )
            {
                ChannelMessenger.MSG_PERF.debug( "constructed message id " + msg.getMessageID() + " for " +
                                                 methodName + " on " + address_ );
            }

            if ( PerformanceMonitor.instance().isMsgBuildTimeEnabled() )
            {
                PerformanceMonitor.instance().addMsgBuildTime( address_.getServerIDAsLong(),
                                                               System.currentTimeMillis() - msgBuildStart );
            }

            if ( LOG.isDebugEnabled() )
            {
                NDC.push( "Sending " + methodName + " to " + address_ );
                ndcPushed = true;
                LOG.debug( "About to send message" );
                LOG.debug( "***Msg ID: " + msg.getMessageID() + " sent***" );
            }

            omq_.send( msg, proxyKey_ );

            if ( LOG.isDebugEnabled() )
            {
                LOG.debug( "Sent message" );
            }

            if ( ChannelMessenger.MSG_PERF.isDebugEnabled() )
            {
                ChannelMessenger.MSG_PERF.debug( "sent message id " + msg.getMessageID() );
            }
        }
        catch ( Exception e )
        {
            LOG.error( e.getMessage(), e );
            throw e;
        }
        finally
        {
            if ( ndcPushed )
            {
                NDC.pop();
            }
        }

        return retVal;
    }


    /**
     * Appends call arguments to {@code msg}.  It first writes a packed flag
     * bit set ({@link #NUM_OFFSETS} bits per argument), then each non-null
     * argument.
     *
     * <p>Per-argument flags:
     * <ul>
     * <li>{@link #IS_NULL_OFFSET}: the argument is null and nothing else is
     *     written for it;</li>
     * <li>{@link #IS_REMOTE_OFFSET}: the argument is shared (see
     *     {@link #isShared}).  Its proxy info (address, then instance id) is
     *     written instead of the object;</li>
     * <li>{@link #IS_SUBCLASS_OFFSET}: the argument's runtime class is neither
     *     the declared parameter type nor its primitive counterpart.  The
     *     class id from {@code ClassAndMethodTable} is written before the
     *     object.</li>
     * </ul>
     * {@link #readObjsFromMessage} performs the inverse.
     *
     * @param msg      message to append to
     * @param args     call arguments; {@code null} is treated as no arguments
     * @param argTypes declared parameter types, parallel to {@code args}
     */
    public static void writeObjsToMessage( Message msg, Object[] args,
                                           Class[] argTypes )
    {
        int arglength = args == null ? 0 : args.length;
        int numBits = arglength * NUM_OFFSETS;
        BitSet argFlags = new BitSet( numBits );

        // Pass 1: compute the flags for every argument.
        for ( int i = 0; i < arglength; ++i )
        {
            int base = i * NUM_OFFSETS;
            Object arg = args[i];

            if ( arg == null )
            {
                // Null needs no remote/subclass check.
                argFlags.set( base + IS_NULL_OFFSET );
            }
            else if ( isShared( arg ) )
            {
                argFlags.set( base + IS_REMOTE_OFFSET );
            }
            else if ( !matchesDeclaredType( arg.getClass(), argTypes[i] ) )
            {
                if ( LOG.isDebugEnabled() )
                {
                    LOG.debug( arg.getClass() + " looks like a subtype of " + argTypes[i] );
                }
                argFlags.set( base + IS_SUBCLASS_OFFSET );
            }
        }

        msg.add( MathUtil.bitSetToBytes( argFlags, numBits ) );

        // Pass 2: write the payload for each non-null argument.
        for ( int i = 0; i < arglength; ++i )
        {
            int base = i * NUM_OFFSETS;
            if ( argFlags.get( base + IS_NULL_OFFSET ) )
            {
                continue;
            }

            if ( argFlags.get( base + IS_REMOTE_OFFSET ) )
            {
                // The address and instance id are enough to find it again.
                Pair proxyInfo = getProxyInfo( args[i] );
                if ( LOG.isDebugEnabled() )
                {
                    LOG.debug( "Sending " + argTypes[i] + " proxy on " +
                               proxyInfo + " as arg" );
                }
                msg.add( proxyInfo.getFirst() );
                msg.add( proxyInfo.getSecond() );
            }
            else
            {
                if ( argFlags.get( base + IS_SUBCLASS_OFFSET ) )
                {
                    msg.add( ClassAndMethodTable.instance().
                             getID( args[i].getClass() ) );
                }
                msg.add( args[i] );
            }
        }
    }


    /**
     * Returns true when {@code runtimeType} is {@code declaredType} itself,
     * or is the wrapper of the primitive {@code declaredType} (e.g.
     * {@code Long} for {@code long}).
     */
    private static boolean matchesDeclaredType( Class runtimeType, Class declaredType )
    {
        if ( runtimeType.equals( declaredType ) )
        {
            return true;
        }
        Class primitive = (Class) getDefaultParamClassMap().get( runtimeType );
        return primitive != null && primitive.equals( declaredType );
    }


    /**
     * Reads arguments previously written by {@link #writeObjsToMessage}.
     * Shared (remote) arguments are turned back into proxies through
     * {@link #newInstance(Class, MessengerID, int, SelectionKey)}, using the
     * message's client key.
     *
     * @param msg      message positioned at the argument flags
     * @param argTypes declared parameter types of the target method
     * @return one entry per parameter; null where a null was sent
     * @throws IOException if the stream cannot be read, including a
     *         truncated flag block
     */
    public static Object[] readObjsFromMessage( Message msg, Class[] argTypes )
    throws IOException
    {
        final Object[] args = new Object[argTypes.length];

        // Read the bit vector of per-argument flags.
        int numArgFlagBits = args.length * NUM_OFFSETS;
        byte[] argFlagBytes = new byte[( numArgFlagBits + 7 ) / 8];

        if ( LOG.isDebugEnabled() )
        {
            LOG.debug( "about to read arg flags" );
        }
        // readFully: a plain read(byte[]) may return fewer bytes than
        // requested, and the remaining flags would then be decoded as zero.
        msg.getDataInputStream().readFully( argFlagBytes );
        if ( LOG.isDebugEnabled() )
        {
            LOG.debug( "Read " + argFlagBytes.length + " bytes of arg flags" );
        }
        BitSet argFlags =
        MathUtil.bytesToBitSet( argFlagBytes, numArgFlagBits );

        if ( LOG.isDebugEnabled() )
        {
            for ( int i = 0; i < numArgFlagBits; ++i )
            {
                LOG.debug( "" + argFlags.get( i ) );
            }
        }

        for ( int i = 0; i < argTypes.length; ++i )
        {
            int base = i * NUM_OFFSETS;
            if ( argFlags.get( base + IS_NULL_OFFSET ) )
            {
                continue;   // args[i] stays null
            }

            if ( argFlags.get( base + IS_REMOTE_OFFSET ) )
            {
                MessengerID addy = (MessengerID) DataStreamableUtil.
                readFrom( MessengerID.class,
                          msg.getDataInputStream() );

                int instanceID = msg.getDataInputStream().readInt();

                args[i] = newInstance( argTypes[i], addy,
                                       instanceID, msg.getClientKey() );
            }
            else
            {
                Class argType;
                if ( argFlags.get( base + IS_SUBCLASS_OFFSET ) )
                {
                    int classID = msg.getDataInputStream().readInt();
                    argType = ClassAndMethodTable.instance().
                    getClass( new Integer( classID ) );
                }
                else
                {
                    argType = argTypes[i];
                }

                args[i] = DataStreamableUtil.readFrom( argType,
                                                       msg.getDataInputStream() );
            }
        }

        return args;
    }


    /**
     * Registers {@code dcb} with the messenger for this handler's address.
     */
    public void addDisconnectionListener( DisconnectionListener dcb )
    {
        omq_.getCorrespondingIncomingMessageQueue().getMessenger().
        addDisconnectionListener( address_, dcb );
    }

    /**
     * Registers {@code dcb} with the messenger for the given selection key.
     */
    public void addDisconnectionListener( SelectionKey key,
                                          DisconnectionListener dcb )
    {
        omq_.getCorrespondingIncomingMessageQueue().getMessenger().
        addDisconnectionListener( key, dcb );
    }


    /**
     * Returns the ARMI handler behind {@code obj}.  Returns null if
     * {@code obj} is null, is not a dynamic proxy, or is a proxy backed by
     * some other kind of invocation handler.
     */
    private static ARMIProxyFactory armiHandlerOf( Object obj )
    {
        if ( obj != null && Proxy.isProxyClass( obj.getClass() ) )
        {
            InvocationHandler handler = Proxy.getInvocationHandler( obj );
            if ( handler instanceof ARMIProxyFactory )
            {
                return (ARMIProxyFactory) handler;
            }
        }
        return null;
    }


    /**
     * Tells whether {@code obj} is passed by reference rather than by value.
     * That is the case when it is an ARMI proxy, or when it has been
     * registered through {@link #mapObjectToProxyInfo}.
     */
    public static boolean isShared( Object obj )
    {
        return armiHandlerOf( obj ) != null ||
               objToProxyInfoMap_.containsKey( obj );
    }


    /**
     * Returns a (MessengerID address, Integer instanceID) pair identifying
     * {@code obj}.  The pair comes from the proxy's handler for ARMI proxies,
     * and otherwise from the mapping registered via
     * {@link #mapObjectToProxyInfo}.
     *
     * @return the pair, or null if {@code obj} is not shared
     */
    public static Pair getProxyInfo( Object obj )
    {
        ARMIProxyFactory handler = armiHandlerOf( obj );
        if ( handler != null )
        {
            return new Pair( handler.address_,
                             new Integer( handler.instanceID_ ) );
        }

        return (Pair) objToProxyInfoMap_.get( obj );
    }


    /**
     * Records that local object {@code obj} is exported as
     * ({@code address}, {@code instanceID}).  Afterwards it is sent by
     * reference when passed as an argument.
     */
    public static void mapObjectToProxyInfo( Object obj, MessengerID address,
                                             int instanceID )
    {
        objToProxyInfoMap_.put( obj,
                                new Pair( address, new Integer( instanceID ) ) );
    }


    /**
     * Returns the wrapper-to-primitive class map (e.g. {@code Long.class} to
     * {@code long.class}).  It is used to treat boxed arguments as matching
     * primitive parameter types.  The map is built lazily on first use.
     */
    public static Map getDefaultParamClassMap()
    {
        if ( defaultParamClassMap_ == null )
        {
            defaultParamClassMap_ = new HashMap();
            defaultParamClassMap_.put( Long.class, long.class );
            defaultParamClassMap_.put( Integer.class, int.class );
            defaultParamClassMap_.put( Short.class, short.class );
            defaultParamClassMap_.put( Byte.class, byte.class );
            defaultParamClassMap_.put( Boolean.class, boolean.class );
        }

        return defaultParamClassMap_;
    }


    /** Returns the outgoing queue that all proxies send through. */
    public static OutgoingMessageQueue getOutgoingMessageQueue()
    {
        return omq_;
    }


    /** Hash code of a proxy, derived from its (address, instance id) pair. */
    public int proxyHashCode( Object proxy )
    {
        return getProxyInfo( proxy ).hashCode();
    }


    /**
     * Two objects are equal when they resolve to the same (address, instance
     * id) pair.  A null {@code rhs} is never equal.
     */
    public boolean proxyEquals( Object lhs, Object rhs )
    {
        if ( lhs == rhs )
        {
            return true;
        }
        if ( rhs == null )
        {
            return false;
        }
        Pair lhsInfo = getProxyInfo( lhs );
        return lhsInfo != null && lhsInfo.equals( getProxyInfo( rhs ) );
    }


    /** Tells the ack listener to expect an acknowledgement for messageID. */
    public static void expectAckFor( long messageID )
    {
        AckMessageListener.instance().expectAckFor( messageID );
    }


    /** Boxed overload of {@link #expectAckFor(long)}. */
    public static void expectAckFor( Long messageID )
    {
        expectAckFor( messageID.longValue() );
    }


    /**
     * Delegates to the response listener to obtain the return value for
     * {@code messageID}, decoded as {@code returnType}.
     */
    public static Object waitForReturn( long messageID, Class returnType )
    {
        return RMIResponseMessageListener.instance().waitForReturn( messageID,
                                                                    returnType );
    }


    /**
     * Returns this process's address: the messenger's local address plus the
     * incoming queue's local port.  It is computed on first use.  On failure
     * the error is logged and null is returned; a later call retries.
     */
    public static ServerAddress getLocalAddress()
    {
        if ( localAddress_ == null )
        {
            try
            {
                localAddress_ =
                ServerAddress.make( getMessenger().getLocalAddress(),
                                    imq_.getLocalPort() );
            }
            catch ( Exception e )
            {
                LOG.error( e.getMessage(), e );
            }
        }
        return localAddress_;
    }


    /**
     * Returns the cached handler for (server id, instance id, key), creating
     * and caching one if needed.
     */
    public static ARMIProxyFactory forObject( MessengerID address,
                                              int instanceID, SelectionKey proxyKey )
    {
        Triplet key = new Triplet( new Long( address.getServerID() ),
                                   new Integer( instanceID ), proxyKey );

        // The lookup and the insert must be one atomic step.  Synchronizing
        // on a Collections.synchronizedMap is the documented way to get that.
        synchronized ( instanceMap_ )
        {
            ARMIProxyFactory retval = (ARMIProxyFactory) instanceMap_.get( key );
            if ( retval == null )
            {
                retval = new ARMIProxyFactory( address, instanceID );
                retval.setProxyKey( proxyKey );
                instanceMap_.put( key, retval );
            }
            return retval;
        }
    }


    /**
     * Use this method to get a proxy for the singleton instance (id 0) of the
     * interface given, on the server at the address given and the default
     * port.  You will have to cast the returned object to the interface type.
     */
    static public Object newInstance( Class intrface, InetAddress inetAddress )
    {
        return newInstance( intrface, inetAddress, 0 );
    }


    /**
     * Use this method to get a proxy for the instance {@code instanceID} of
     * the interface given, on the server at the address given and the default
     * port.  You will have to cast the returned object to the interface type.
     */
    static public Object newInstance( Class intrface, InetAddress inetAddress,
                                      int instanceID )
    {
        ServerAddress address = ServerAddress.make( inetAddress, DEFAULT_PORT );

        return newInstance( intrface, address, instanceID );
    }


    /**
     * Use this method to get a proxy for the singleton instance (id 0) of the
     * interface given, on the server at the address given.  You will have to
     * cast the returned object to the interface type.
     */
    static public Object newInstance( Class intrface, ServerAddress address )
    {
        return newInstance( intrface, address, 0 );
    }


    /**
     * Use this method to get a proxy for the instance {@code instanceID} of
     * the interface given, on the server at the address given.  The selection
     * key is looked up from the incoming queue's socket.  You will have to
     * cast the returned object to the interface type.
     */
    static public Object newInstance( Class intrface,
                                      ServerAddress serverAddress,
                                      int instanceID )
    {
        SelectionKey proxyKey =
        StandardUDPIncomingMessageQueue.instance().getSocket().getKeyForAddress( serverAddress );
        return newInstance( intrface, new MessengerID( serverAddress, 0 ),
                            instanceID, proxyKey );
    }


    /**
     * Use this method to get a proxy for the instance {@code instanceID} of
     * the interface given, at {@code address} and reached through
     * {@code proxyKey}.  Proxies are cached per (interface, instance id, key).
     * You will have to cast the returned object to the interface type.
     */
    static public Object newInstance( Class intrface,
                                      MessengerID address,
                                      int instanceID, SelectionKey proxyKey )
    {
        Triplet key = new Triplet( intrface, new Integer( instanceID ), proxyKey );

        Object retval = interfaceAddressToProxy_.get( key );

        if ( LOG.isDebugEnabled() )
        {
            LOG.debug( "Getting proxy for " + intrface + ", id: " + instanceID +
                       " on " + address );
        }

        if ( retval == null )
        {
            if ( LOG.isDebugEnabled() )
            {
                LOG.debug( "Proxy not in cache, caching now" );
            }
            retval = Proxy.newProxyInstance( intrface.getClassLoader(),
                                             new Class[]{intrface},
                                             forObject( address, instanceID, proxyKey ) );

            interfaceAddressToProxy_.put( key, retval );
        }
        else
        {
            if ( LOG.isDebugEnabled() )
            {
                LOG.debug( "Proxy already in cache" );
            }
        }

        return retval;
    }


    /** Returns the messenger of the incoming message queue. */
    public static UDPMessengerInterface getMessenger()
    {
        return imq_.getMessenger();
    }


    /** Sets the selection key that this handler's messages are sent through. */
    public void setProxyKey( SelectionKey proxyKey )
    {
        proxyKey_ = proxyKey;
    }


    // Identity of the remote instance this handler forwards to.
    protected int instanceID_;
    protected MessengerID address_;
    protected SelectionKey proxyKey_ = null;

    // Local objects exported by reference -> Pair(address, instance id).
    protected static Map objToProxyInfoMap_ = new HashMap();
    // Triplet(interface, instance id, key) -> cached proxy.
    protected static Map interfaceAddressToProxy_ = new HashMap();
    protected static OutgoingMessageQueue omq_ =
    StandardUDPOutgoingMessageQueue.instance();
    protected static StandardUDPIncomingMessageQueue imq_;
    protected static ClassAndMethodTable t_;
    public static int DEFAULT_PORT = 0xB0B;
    protected static ServerAddress localAddress_;
    protected static Map defaultParamClassMap_;

    // Triplet(server id, instance id, key) -> cached handler.
    protected static Map instanceMap_ =
    Collections.synchronizedMap( new HashMap() );

    // Bit positions within each argument's NUM_OFFSETS-bit flag group.
    public static int IS_NULL_OFFSET = 0;
    public static int IS_SUBCLASS_OFFSET = 1;
    public static int IS_REMOTE_OFFSET = 2;
    public static int NUM_OFFSETS = 3;

    //debugging code follows

    //client side main for testing UDP RMI
    public static void main( String[] args ) throws Exception
    {
        /** TODO replace this with a com.navtools.networking.armi.networking.test interface that doesn't change,
         *  then uncomment
         InetAddress inetAddress = InetAddress.getByName("localhost");
         ARMIProxyFactory.initialize(inetAddress);

         ServerInterface i = (ServerInterface)ARMIProxyFactory.newInstance(ServerInterface.class, inetAddress);

         ObjectID id1 = ObjectID.make(1);
         ObjectID id2 = ObjectID.make(2);
         ClientInterface client = null;
         TextMessage textMessage = TextMessage.make();
         ObjectInformation info = ObjectInformation.make();

         i.logoff( Message.getNextID(), id1 );
         i.requestLogon( Message.getNextID(), "Bobby", "Wurp" );
         i.requestServerList(Message.getNextID());
         i.requestInventory(Message.getNextID(), id1);
         i.init(Message.getNextID(), client, getLocalAddress());
         i.shutdown(Message.getNextID());
         i.updateObject(Message.getNextID(), id1, info );
         i.pickUp(Message.getNextID(), id1 );
         i.drop(Message.getNextID(), id1 );
         i.sendTextMessage(Message.getNextID(),textMessage);
         i.createCharacter(Message.getNextID(),"Bobby", "Wurp",
         "shapeName", "textureName");
         //List getInstalledModels(Message.getNextID());
         //List getInstalledTextures(Message.getNextID());
         i.getAllObjects(Message.getNextID());
         i.putObjectIn(Message.getNextID(),id1, id2);
         i.attack(Message.getNextID(),id1);
         */
    }
}

