/*
 * JBoss, the OpenSource EJB server
 *
 * Distributable under LGPL license.
 * See terms of license at gnu.org.
 */
package org.jboss.ejb.plugins.jrmp.server;

import java.awt.Component;
import java.beans.beancontext.BeanContextChildComponentProxy;
import java.io.File;
import java.io.IOException;
import java.lang.reflect.Method;
import java.lang.reflect.Constructor;
import java.net.UnknownHostException;
import java.rmi.ServerException;
import java.rmi.RemoteException;
import java.rmi.MarshalledObject;
import java.rmi.server.RemoteServer;
import java.rmi.server.RMIClientSocketFactory;
import java.rmi.server.RMIServerSocketFactory;
import java.rmi.server.UnicastRemoteObject;
import java.security.Principal;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import java.util.HashMap;
import java.util.Properties;

import javax.ejb.EJBMetaData;
import javax.ejb.EJBHome;
import javax.ejb.EJBObject;
import javax.naming.Name;
import javax.naming.InitialContext;
import javax.naming.Context;
import javax.naming.NamingException;
import javax.naming.NameNotFoundException;
import javax.transaction.Transaction;
import javax.transaction.TransactionManager;

import org.jboss.ejb.MethodInvocation;

import org.jboss.ejb.Container;
import org.jboss.ejb.ContainerInvokerContainer;
import org.jboss.ejb.Interceptor;
import org.jboss.ejb.ContainerInvoker;
import org.jboss.ejb.plugins.jrmp.interfaces.HomeProxy;
import org.jboss.ejb.plugins.jrmp.interfaces.HomeHandleImpl;
import org.jboss.ejb.plugins.jrmp.interfaces.RemoteMethodInvocation;
import org.jboss.ejb.plugins.jrmp.interfaces.StatelessSessionProxy;
import org.jboss.ejb.plugins.jrmp.interfaces.StatefulSessionProxy;
import org.jboss.ejb.plugins.jrmp.interfaces.EntityProxy;
import org.jboss.ejb.plugins.jrmp.interfaces.GenericProxy;
import org.jboss.ejb.plugins.jrmp.interfaces.ContainerRemote;
import org.jboss.ejb.plugins.jrmp.interfaces.EJBMetaDataImpl;

import org.jboss.tm.TransactionPropagationContextFactory;

import org.jboss.security.SecurityAssociation;

import org.jboss.logging.Logger;

import org.jboss.deployment.DeploymentException;
import org.jboss.metadata.XmlLoadable;
import org.jboss.metadata.MetaData;
import org.jboss.metadata.EntityMetaData;
import org.jboss.metadata.SessionMetaData;
import org.jboss.security.SecurityDomain;

import org.w3c.dom.Element;

/**
 * JRMP (RMI) container invoker.
 *
 * <p>Exports itself over RMI, binds the container's EJBHome and a bare
 * invoker into JNDI, and dispatches both marshalled remote invocations and
 * optimized in-VM invocations to the container. Creation of EJBHome and
 * EJBObject references is delegated to a JDK-version specific
 * {@link ContainerInvoker} chosen in {@link #importXml(Element)}.
 *
 * @author Rickard Oberg (rickard.oberg@telkel.com)
 * @author <a href="mailto:sebastien.alborini@m4x.org">Sebastien Alborini</a>
 * @author <a href="mailto:marc.fleury@telkel.com">Marc Fleury</a>
 * @author <a href="mailto:jplindfo@cc.helsinki.fi">Juha Lindfors</a>
 * @author <a href="mailto:Scott.Stark@jboss.org">Scott Stark</a>
 * @version $Revision: 1.37.4.7 $
 */
public class JRMPContainerInvoker
   extends RemoteServer
   implements ContainerRemote, ContainerInvoker, XmlLoadable
{
   static Logger log = Logger.getLogger(JRMPContainerInvoker.class);
   // Constants -----------------------------------------------------
   /** Port value meaning "let RMI pick any free port". */
   protected final static int ANONYMOUS_PORT = 0;

   /** Logged when the entity primary key class cannot be determined. */
   private final static String PK_ERROR_MESSAGE =
      "Unable to identify Bean's Primary Key class!  Did you specify a primary key class and/or field?  Does that field exist?";

   // Attributes ----------------------------------------------------
   /** Whether in-VM calls may bypass marshalling (set from the Optimized element). */
   protected boolean optimize = false;
   /** The port the container will be exported on */
   protected int rmiPort = ANONYMOUS_PORT;
   /** An optional custom client socket factory */
   protected RMIClientSocketFactory clientSocketFactory;
   /** An optional custom server socket factory */
   protected RMIServerSocketFactory serverSocketFactory;
   /** The class name of the optional custom client socket factory */
   protected String clientSocketFactoryName;
   /** The class name of the optional custom server socket factory */
   protected String serverSocketFactoryName;
   /** The SecurityDomain instance for the ssl-domain setting */
   protected SecurityDomain sslDomain;
   /** The address to bind the rmi port on */
   protected String serverAddress;
   /** True when running on a VM whose version sorts below "1.3". */
   protected boolean jdk122 = false;
   protected Container container;
   protected ContainerInvokerContainer invokerContainer;
   protected String jndiName;
   protected EJBMetaDataImpl ejbMetaData;
   /** The single home reference for this container. */
   protected EJBHome home;
   /** The single stateless session object reference for this container. */
   protected EJBObject statelessObject;

   /** Method hash (Long) to remote interface Method. */
   protected Map beanMethodInvokerMap;
   /** Method hash (Long) to home interface Method (plus Handle.getEJBObject). */
   protected Map homeMethodInvokerMap;

   /** Delegate chosen by JDK version; creates the home/object references. */
   protected ContainerInvoker ciDelegate;

   // Static --------------------------------------------------------

   private static TransactionPropagationContextFactory tpcFactory;

   // Constructors --------------------------------------------------

   // Public --------------------------------------------------------
   public void setOptimized(boolean optimize)
   {
      this.optimize = optimize;
   }

   public boolean isOptimized()
   {
      return optimize;
   }

   public String getJndiName()
   {
      return jndiName;
   }

   // ContainerService implementation -------------------------------
   /**
    * Sets the container and forwards it to the JDK-specific delegate.
    * The delegate must already exist (it is created in importXml).
    */
   public void setContainer(Container con)
   {
      this.container = con;
      this.invokerContainer = (ContainerInvokerContainer) con;
      ciDelegate.setContainer(con);
   }

   /**
    * Looks up the transaction services, builds the method-hash maps and the
    * EJBMetaData, then initializes the delegate.
    *
    * @throws Exception if a JNDI lookup fails, the primary key class cannot
    *    be loaded, or the delegate fails to initialize
    */
   public void init()
      throws Exception
   {
      Context ctx = new InitialContext();

      jndiName = container.getBeanMetaData().getJndiName();

      // Get the transaction propagation context factory
      tpcFactory = (TransactionPropagationContextFactory)ctx.lookup("java:/TransactionPropagationContextExporter");

      // Set the transaction manager and transaction propagation
      // context factory of the GenericProxy class
      GenericProxy.setTransactionManager((TransactionManager)ctx.lookup("java:/TransactionManager"));
      GenericProxy.setTPCFactory(tpcFactory);

      // Create method mappings for container invoker
      beanMethodInvokerMap = createMethodMap(invokerContainer.getRemoteClass().getMethods());
      homeMethodInvokerMap = createMethodMap(invokerContainer.getHomeClass().getMethods());

      try
      {
         // Handle.getEJBObject is dispatched through the home invoker as well
         Method getEJBObjectMethod = Class.forName("javax.ejb.Handle").getMethod("getEJBObject", new Class[0]);
         homeMethodInvokerMap.put(new Long(RemoteMethodInvocation.calculateHash(getEJBObjectMethod)), getEJBObjectMethod);
      }
      catch (Exception e)
      {
         log.error("getEJBObject", e);
      }

      ejbMetaData = createEJBMetaData();

      ciDelegate.init();
   }

   /**
    * Exports this invoker over RMI, registers it as a local invoker, and
    * binds the home under the bean's JNDI name and the invoker under
    * "invokers/" + JNDI name.
    *
    * @throws ServerException wrapping any IOException (including an RMI
    *    export failure). NamingExceptions from the binds are not
    *    IOExceptions and propagate unwrapped.
    */
   public void start()
      throws Exception
   {
      String name = container.getBeanMetaData().getJndiName();
      try
      {
         // Export CI
         UnicastRemoteObject.exportObject(this, rmiPort,
             clientSocketFactory, serverSocketFactory);
         GenericProxy.addLocal(name, this);

         InitialContext context = new InitialContext();

         // Bind the home in the JNDI naming space
         rebind(context, name, invokerContainer.getContainerInvoker().getEJBHome());

         // Bind a bare bones invoker in the JNDI invoker naming space
         rebind(context, "invokers/" + name, invokerContainer.getContainerInvoker());

         log.debug("Bound "+container.getBeanMetaData().getEjbName() + " to " + name);
      } catch (IOException e)
      {
         throw new ServerException("Could not bind either home or invoker", e);
      }
   }

   /**
    * Unbinds the JNDI entries, unexports the RMI object and clears the
    * method maps and cached method hashes. Each undeploy step is best
    * effort and independent, so a failed unbind no longer prevents the
    * RMI object from being unexported.
    */
   public void stop()
   {
      String name = container.getBeanMetaData().getJndiName();

      InitialContext ctx = null;
      try
      {
         ctx = new InitialContext();
      }
      catch (Exception e)
      {
         log.debug("Failed to create InitialContext while stopping " + name + ": " + e);
      }
      if (ctx != null)
      {
         try
         {
            ctx.unbind(name);
         }
         catch (Exception e)
         {
            log.debug("Failed to unbind home " + name + ": " + e);
         }
         try
         {
            ctx.unbind("invokers/" + name);
         }
         catch (Exception e)
         {
            log.debug("Failed to unbind invoker invokers/" + name + ": " + e);
         }
      }
      try
      {
         UnicastRemoteObject.unexportObject(this, true);
      }
      catch (Exception e)
      {
         log.debug("Failed to unexport invoker for " + name + ": " + e);
      }

      GenericProxy.removeLocal(name);
      // The maps are only created by init(); guard against stop() without init()
      if (beanMethodInvokerMap != null)
         beanMethodInvokerMap.clear();
      if (homeMethodInvokerMap != null)
         homeMethodInvokerMap.clear();

      // Remove method mappings for container invoker
      clearMethodHashes(invokerContainer.getRemoteClass().getMethods());
      clearMethodHashes(invokerContainer.getHomeClass().getMethods());

      log.debug("Cleared method maps");
   }

   /** Releases the container references and destroys the delegate. */
   public void destroy()
   {
      container = null;
      invokerContainer = null;
      if( ciDelegate != null )
         ciDelegate.destroy();
      ciDelegate = null;
   }

   // ContainerInvoker implementation -------------------------------
   public EJBMetaData getEJBMetaData()
   {
      return ejbMetaData;
   }

   public EJBHome getEJBHome()
   {
      return ciDelegate.getEJBHome();
   }

   public EJBObject getStatelessSessionEJBObject()
      throws RemoteException
   {
      return ciDelegate.getStatelessSessionEJBObject();
   }

   public EJBObject getStatefulSessionEJBObject(Object id)
      throws RemoteException
   {
      return ciDelegate.getStatefulSessionEJBObject(id);
   }

   public EJBObject getEntityEJBObject(Object id)
      throws RemoteException
   {
      return ciDelegate.getEntityEJBObject(id);
   }

   public Collection getEntityCollection(Collection ids)
      throws RemoteException
   {
      return ciDelegate.getEntityCollection(ids);
   }

   // ContainerRemote implementation --------------------------------

   /**
    *  Invoke a Home interface method.
    *  The invocation is unmarshalled and run with the container's class
    *  loader as the thread context class loader.
    */
   public MarshalledObject invokeHome(MarshalledObject mimo)
      throws Exception
   {
      ClassLoader oldCl = Thread.currentThread().getContextClassLoader();
      Thread.currentThread().setContextClassLoader(container.getClassLoader());

      try
      {
         RemoteMethodInvocation rmi = (RemoteMethodInvocation)mimo.get();
         rmi.setMethodMap(homeMethodInvokerMap);

         return new MarshalledObject(container.invokeHome(new MethodInvocation(null, rmi.getMethod(), rmi.getArguments(),
            rmi.getPrincipal(), rmi.getCredential(), rmi.getTransactionPropagationContext() )));
      } finally
      {
         Thread.currentThread().setContextClassLoader(oldCl);
      }
   }

   /**
    *  Invoke a Remote interface method.
    *  The invocation is unmarshalled and run with the container's class
    *  loader as the thread context class loader.
    */
   public MarshalledObject invoke(MarshalledObject mimo)
      throws Exception
   {
      ClassLoader oldCl = Thread.currentThread().getContextClassLoader();
      Thread.currentThread().setContextClassLoader(container.getClassLoader());

      try
      {
         RemoteMethodInvocation rmi = (RemoteMethodInvocation)mimo.get();
         rmi.setMethodMap(beanMethodInvokerMap);

         return new MarshalledObject(container.invoke(new MethodInvocation(rmi.getId(), rmi.getMethod(), rmi.getArguments(),
            rmi.getPrincipal(), rmi.getCredential(), rmi.getTransactionPropagationContext() )));
      } finally
      {
         Thread.currentThread().setContextClassLoader(oldCl);
      }
   }

   /**
    *  Invoke a Home interface method.
    *  This is for optimized local calls. If the method's declaring class is
    *  not assignable from this container's home class (e.g. it was loaded by
    *  a different class loader), the call falls back to the marshalling path.
    *  On that path the principal and credential come from SecurityAssociation,
    *  not from the identity/credential arguments.
    */
   public Object invokeHome(Method m, Object[] args, Transaction tx,
      Principal identity, Object credential)
      throws Exception
   {
      Class homeClass = invokerContainer.getHomeClass();
      if (!m.getDeclaringClass().isAssignableFrom(homeClass))
      {
         RemoteMethodInvocation rmi = createMarshallingInvocation(null, m, args, tx);

         // Invoke on the container, enforce marshalling
         try
         {
            return invokeHome(new MarshalledObject(rmi)).get();
         } catch (Exception e)
         {
            // Re-marshal so the caller sees an exception from its own class loader
            throw (Exception)new MarshalledObject(e).get();
         }
      }

      // Set the right context classloader
      ClassLoader oldCl = Thread.currentThread().getContextClassLoader();
      Thread.currentThread().setContextClassLoader(container.getClassLoader());

      try
      {
         return container.invokeHome(new MethodInvocation(null, m, args, tx,
            identity, credential));
      } finally
      {
         Thread.currentThread().setContextClassLoader(oldCl);
      }
   }

   /**
    *  Invoke a Remote interface method.
    *  This is for optimized local calls. If the method's declaring class is
    *  not assignable from this container's remote class, the call falls back
    *  to the marshalling path. On that path the principal and credential come
    *  from SecurityAssociation, not from the identity/credential arguments.
    */
   public Object invoke(Object id, Method m, Object[] args, Transaction tx,
      Principal identity, Object credential )
      throws Exception
   {
      Class remoteClass = invokerContainer.getRemoteClass();
      if (!m.getDeclaringClass().isAssignableFrom(remoteClass))
      {
         RemoteMethodInvocation rmi = createMarshallingInvocation(id, m, args, tx);

         // Invoke on the container, enforce marshalling
         try
         {
            return invoke(new MarshalledObject(rmi)).get();
         } catch (Exception e)
         {
            // Re-marshal so the caller sees an exception from its own class loader
            throw (Exception)new MarshalledObject(e).get();
         }
      }

      // Set the right context classloader
      ClassLoader oldCl = Thread.currentThread().getContextClassLoader();
      Thread.currentThread().setContextClassLoader(container.getClassLoader());

      try
      {
         return container.invoke(new MethodInvocation(id, m, args, tx, identity, credential));
      }
      finally
      {
         Thread.currentThread().setContextClassLoader(oldCl);
      }
   }


   // XmlLoadable implementation
   /**
    * Reads the invoker configuration: Optimized, RMIObjectPort,
    * RMIClientSocketFactory, RMIServerSocketFactory, ssl-domain and
    * RMIServerSocketAddr. It also creates the JDK-specific delegate and
    * instantiates any custom socket factories.
    *
    * @throws DeploymentException if Optimized cannot be read or the
    *    ssl-domain JNDI lookup fails
    */
   public void importXml(Element element) throws DeploymentException
   {
      Element optElement = MetaData.getUniqueChild(element, "Optimized");
      if( optElement != null )
      {
         String opt = MetaData.getElementContent(optElement);
         optimize = Boolean.valueOf(opt).booleanValue();
      }

      // NOTE: this is a lexicographic String comparison of java.vm.version
      jdk122 = System.getProperty("java.vm.version").compareTo("1.3") < 0;

      // Create delegate depending on JDK version
      if (jdk122)
         ciDelegate = new org.jboss.ejb.plugins.jrmp12.server.JRMPContainerInvoker(this);
      else
         ciDelegate = new org.jboss.ejb.plugins.jrmp13.server.JRMPContainerInvoker(this);

      try
      {
         Element portElement = MetaData.getUniqueChild(element, "RMIObjectPort");
         if( portElement != null )
         {
            String port = MetaData.getElementContent(portElement);
            rmiPort = Integer.parseInt(port);
         }
      }
      catch(NumberFormatException e)
      {
         log.error("Invalid RMIObjectPort, using an anonymous port", e);
         rmiPort = ANONYMOUS_PORT;
      }
      catch(DeploymentException e)
      {
         rmiPort = ANONYMOUS_PORT;
      }

      // Load any custom socket factories
      ClassLoader loader = Thread.currentThread().getContextClassLoader();
      try
      {
         Element csfElement = MetaData.getOptionalChild(element, "RMIClientSocketFactory");
         if( csfElement != null )
            clientSocketFactoryName = MetaData.getElementContent(csfElement);
      }
      catch(Exception e)
      {
         log.error("Failed to read RMIClientSocketFactory", e);
         clientSocketFactoryName = null;
      }
      try
      {
         Element ssfElement = MetaData.getOptionalChild(element, "RMIServerSocketFactory");
         if( ssfElement != null )
            serverSocketFactoryName = MetaData.getElementContent(ssfElement);
      }
      catch(Exception e)
      {
         log.error("Failed to read RMIServerSocketFactory", e);
         serverSocketFactoryName = null;
      }

      // Load the optional ssl-domain giving the JNDI name of the SecurityDomain
      Element sslDomainElement = MetaData.getOptionalChild(element, "ssl-domain");
      if( sslDomainElement != null )
      {
         String domainName = MetaData.getElementContent(sslDomainElement);
         try
         {
            InitialContext iniCtx = new InitialContext();
            sslDomain = (SecurityDomain) iniCtx.lookup(domainName);
         }
         catch(Exception e)
         {
            throw new DeploymentException("Failed to locate ssl-domain", e);
         }
      }

      Element addrElement = MetaData.getOptionalChild(element, "RMIServerSocketAddr");
      if( addrElement != null )
         this.serverAddress = MetaData.getElementContent(addrElement);
      loadCustomSocketFactories(loader);

      log.debug("Container Invoker RMI Port='"+(rmiPort == ANONYMOUS_PORT ? "Anonymous" : Integer.toString(rmiPort))+"'");
      log.debug("Container Invoker Client SocketFactory='"+(clientSocketFactory == null ? "Default" : clientSocketFactory.toString())+"'");
      log.debug("Container Invoker Server SocketFactory='"+(serverSocketFactory == null ? "Default" : serverSocketFactory.toString())+"'");
      log.debug("Container Invoker Server SocketAddr='"+(serverAddress == null ? "Default" : serverAddress)+"'");
      log.debug("Container Invoker Server sslDomain='"+(sslDomain == null ? "Default" : sslDomain.getSecurityDomain())+"'");
      log.debug("Container Invoker Optimize='"+optimize+"'");
   }


   // Package protected ---------------------------------------------

   // Protected -----------------------------------------------------
   /**
    * Binds val to name in ctx, creating any missing intermediate
    * subcontexts along the way.
    */
   protected void rebind(Context ctx, String name, Object val)
      throws NamingException
   {
      Name n = ctx.getNameParser("").parse(name);
      while (n.size() > 1)
      {
         String ctxName = n.get(0);
         try
         {
            ctx = (Context)ctx.lookup(ctxName);
         } catch (NameNotFoundException e)
         {
            ctx = ctx.createSubcontext(ctxName);
         }
         n = n.getSuffix(1);
      }

      ctx.rebind(n.get(0), val);
   }

   // Private -------------------------------------------------------
   /** Builds a map from RemoteMethodInvocation hash (as Long) to Method. */
   private static Map createMethodMap(Method[] methods)
   {
      Map map = new HashMap();
      for (int i = 0; i < methods.length; i++)
         map.put(new Long(RemoteMethodInvocation.calculateHash(methods[i])), methods[i]);
      return map;
   }

   /** Drops the cached RemoteMethodInvocation hash of every given method. */
   private static void clearMethodHashes(Method[] methods)
   {
      for (int i = 0; i < methods.length; i++)
         RemoteMethodInvocation.clearHash(methods[i]);
   }

   /**
    * Creates the EJBMetaData for this container: entity beans carry their
    * primary key class; session beans have no primary key and are flagged
    * stateless or stateful from their SessionMetaData.
    */
   private EJBMetaDataImpl createEJBMetaData()
      throws ClassNotFoundException
   {
      if (container.getBeanMetaData() instanceof EntityMetaData)
      {
         Class pkClass = resolvePrimaryKeyClass((EntityMetaData)container.getBeanMetaData());
         return new EJBMetaDataImpl(
            invokerContainer.getRemoteClass(),
            invokerContainer.getHomeClass(),
            pkClass,
            false, //Session
            false, //Stateless
            new HomeHandleImpl(jndiName));
      }

      boolean stateless = ((SessionMetaData)container.getBeanMetaData()).isStateless();
      return new EJBMetaDataImpl(
         invokerContainer.getRemoteClass(),
         invokerContainer.getHomeClass(),
         null, //No PK
         true, //Session
         stateless,
         new HomeHandleImpl(jndiName));
   }

   /**
    * Determines an entity's primary key class: the declared prim-key-class
    * if present, otherwise the declared type of the public primkey-field on
    * the bean class.
    *
    * @throws RuntimeException if neither a prim-key-class nor a usable
    *    primkey-field is available
    * @throws ClassNotFoundException if a named class cannot be loaded
    */
   private Class resolvePrimaryKeyClass(EntityMetaData metaData)
      throws ClassNotFoundException
   {
      ClassLoader loader = container.getClassLoader();
      String pkClassName = metaData.getPrimaryKeyClass();
      if (pkClassName != null)
         return loader.loadClass(pkClassName);

      String ejbClassName = metaData.getEjbClass();
      String pkFieldName = metaData.getPrimKeyField();
      if (ejbClassName == null || pkFieldName == null)
      {
         log.error(PK_ERROR_MESSAGE);
         throw new RuntimeException("Primary Key Problem");
      }
      try
      {
         // getType() yields the field's declared class; getClass() would
         // wrongly return java.lang.reflect.Field
         return loader.loadClass(ejbClassName).getField(pkFieldName).getType();
      }
      catch (NoSuchFieldException e)
      {
         log.error(PK_ERROR_MESSAGE, e);
         throw new RuntimeException("Primary Key Problem");
      }
   }

   /**
    * Builds a RemoteMethodInvocation for the non-optimizable path. It carries
    * the transaction propagation context for tx and the current
    * SecurityAssociation principal and credential.
    */
   private RemoteMethodInvocation createMarshallingInvocation(Object id, Method m, Object[] args, Transaction tx)
      throws Exception
   {
      RemoteMethodInvocation rmi = new RemoteMethodInvocation(id, m, args);

      // Set the transaction propagation context
      rmi.setTransactionPropagationContext(tpcFactory.getTransactionPropagationContext(tx));

      // Set the security stuff
      rmi.setPrincipal( SecurityAssociation.getPrincipal() );
      rmi.setCredential( SecurityAssociation.getCredential() );
      return rmi;
   }

   /**
    * Instantiates the configured custom client/server socket factories.
    * If a server factory is configured, the bind address and ssl-domain
    * are applied reflectively when the factory supports them. If only a
    * bind address is configured, a DefaultSocketFactory bound to it is used.
    * Failures are logged and leave the corresponding factory null (default).
    */
   private void loadCustomSocketFactories(ClassLoader loader)
   {
      try
      {
         if( clientSocketFactoryName != null )
         {
            Class csfClass = loader.loadClass(clientSocketFactoryName);
            clientSocketFactory = (RMIClientSocketFactory) csfClass.newInstance();
         }
      }
      catch(Exception e)
      {
         log.error("Failed to create RMIClientSocketFactory "+clientSocketFactoryName, e);
         clientSocketFactory = null;
      }
      try
      {
         if( serverSocketFactoryName != null )
         {
            Class ssfClass = loader.loadClass(serverSocketFactoryName);
            serverSocketFactory = (RMIServerSocketFactory) ssfClass.newInstance();
            // If a specific bind address was specified, see if the server
            // socket factory supports setBindAddress(String)
            if( serverAddress != null )
            {
               try
               {
                  Class[] parameterTypes = {String.class};
                  Method m = ssfClass.getMethod("setBindAddress", parameterTypes);
                  Object[] args = {serverAddress};
                  m.invoke(serverSocketFactory, args);
               }
               catch(NoSuchMethodException e)
               {
                  log.error("Socket factory does not support setBindAddress(String)");
                  // Go with default address
               }
               catch(Exception e)
               {
                  log.error("Failed to setBindAddress="+serverAddress+" on socket factory", e);
                  // Go with default address
               }
            }
            // If an sslDomain was specified, see if the server socket
            // factory supports setSecurityDomain(SecurityDomain)
            if( sslDomain != null )
            {
               try
               {
                  Class[] parameterTypes = {SecurityDomain.class};
                  Method m = ssfClass.getMethod("setSecurityDomain", parameterTypes);
                  Object[] args = {sslDomain};
                  m.invoke(serverSocketFactory, args);
               }
               catch(NoSuchMethodException e)
               {
                  log.error("Socket factory does not support setSecurityDomain(SecurityDomain)");
               }
               catch(Exception e)
               {
                  log.error("Failed to setSecurityDomain="+sslDomain+" on socket factory", e);
               }
            }
         }
         // If a bind address was specified create a DefaultSocketFactory
         else if( serverAddress != null )
         {
            DefaultSocketFactory defaultFactory = new DefaultSocketFactory();
            serverSocketFactory = defaultFactory;
            try
            {
               defaultFactory.setBindAddress(serverAddress);
            }
            catch(UnknownHostException e)
            {
               log.error("Failed to setBindAddress="+serverAddress+" on socket factory, "+e.getMessage());
            }
         }
      }
      catch(Exception e)
      {
         log.error("Failed to create RMIServerSocketFactory "+serverSocketFactoryName, e);
         serverSocketFactory = null;
      }
   }

   // Inner classes -------------------------------------------------
}
