/*
 * JBoss, Home of Professional Open Source
 * Copyright 2008, Red Hat Middleware LLC, and others contributors as indicated
 * by the @authors tag. All rights reserved.
 * See the copyright.txt in the distribution for a
 * full listing of individual contributors.
 * This copyrighted material is made available to anyone wishing to use,
 * modify, copy, or redistribute it subject to the terms and conditions
 * of the GNU Lesser General Public License, v. 2.1.
 * This program is distributed in the hope that it will be useful, but WITHOUT A
 * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
 * PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.
 * You should have received a copy of the GNU Lesser General Public License,
 * v.2.1 along with this distribution; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 * MA 02110-1301, USA.
 *
 * (C) 2005-2008, JBoss Inc.
 */
package org.jboss.soa.esb.listeners.gateway.http;

import java.io.BufferedOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Enumeration;
import java.util.Map;
import java.util.StringTokenizer;

import javax.servlet.ServletConfig;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.apache.log4j.Logger;
import org.jboss.internal.soa.esb.listeners.war.Servlet;
import org.jboss.internal.soa.esb.publish.ContractInfo;
import org.jboss.internal.soa.esb.publish.ContractProvider;
import org.jboss.internal.soa.esb.publish.ContractProviderLifecycleResource;
import org.jboss.internal.soa.esb.util.StreamUtils;
import org.jboss.soa.esb.ConfigurationException;
import org.jboss.soa.esb.Service;
import org.jboss.soa.esb.client.ServiceInvoker;
import org.jboss.soa.esb.couriers.FaultMessageException;
import org.jboss.soa.esb.helpers.ConfigTree;
import org.jboss.soa.esb.http.HttpServletSecUtil;
import org.jboss.soa.esb.lifecycle.LifecycleResourceException;
import org.jboss.soa.esb.listeners.ListenerTagNames;
import org.jboss.soa.esb.listeners.config.mappers120.HttpGatewayMapper;
import org.jboss.soa.esb.listeners.message.MessageComposer;
import org.jboss.soa.esb.listeners.message.MessageDeliverException;
import org.jboss.soa.esb.message.Message;
import org.jboss.soa.esb.services.registry.RegistryException;
import org.jboss.soa.esb.services.security.PublicCryptoUtil;
import org.jboss.soa.esb.util.ClassUtil;

/**
 * Http Gateway Servlet.
 * <p/>
 * An instance of this class gets configured for each http-listener urlPattern.
 * <p/>
 * Requests are only served on the configured allowed ports (every port when none are
 * configured; other ports receive 404). A query string beginning with {@code wsdl}
 * (case-insensitive) returns the service contract. Any other request is composed into an ESB
 * {@link Message} and dispatched to the target service, either synchronously (the reply, if
 * any, is decomposed back into the HTTP response) or asynchronously (a statically configured
 * status code and optional payload are returned).
 *
 * @author <a href="mailto:tom.fennelly@jboss.com">tom.fennelly@jboss.com</a>
 * @see org.jboss.internal.soa.esb.listeners.war.HttpGatewayDeploymentFactory
 */
public class HttpGatewayServlet extends HttpServlet {

    private static final Logger logger = Logger.getLogger(HttpGatewayServlet.class);

    public static final String PAYLOAD_AS = "payloadAs";
    public static final String EXCEPTION_MAPPINGS = "httpExceptionMappings";

    public static final String ASYNC_SERVICE_INVOKE = "asyncServiceInvoke";
    public static final String ASYNC_STATUS_CODE = "asyncStatusCode";
    public static final String ASYNC_PAYLOAD = "asyncPayloadPath";
    public static final String ASYNC_PAYLOAD_CONTENT_TYPE= "asyncPayloadContentType";
    public static final String ASYNC_PAYLOAD_CHARACTER_ENCODING = "asyncPayloadCharacterEncoding";

    /** Placeholder in contract documents that is substituted with the request URL. */
    private static final String REQUEST_URL_TOKEN = "@REQUEST_URL@";

    private Service service;
    private ServiceInvoker serviceInvoker;
    private String endpointAddress;
    /**
     * Lazily resolved service contract. Volatile so that the double-checked locking in
     * {@link #getContract()} safely publishes the instance to other request threads.
     */
    private volatile ContractInfo contractInfo = null;
    private boolean asyncInvoke;
    private int asyncStatusCode = HttpServletResponse.SC_OK;
    private byte[] asyncPayload;
    private String asyncContentType;
    private String asyncCharacterEncoding;
    private MessageComposer<HttpRequestWrapper> messageComposer;
    private long blockingTimeout = 30000L;
    private Map<String, Integer> exceptionMappings;
    /** Ports on which requests are served; an empty array means every port is allowed. */
    private int[] allowedPorts = new int[0];

    /**
     * Configures the servlet from its init parameters.
     *
     * @param config servlet configuration carrying the target service, endpoint address,
     *               allowed ports, composer class, sync/async settings and exception mappings.
     * @throws ServletException if the service invoker or message composer cannot be created,
     *                          or if any configuration value is invalid.
     */
    public void init(ServletConfig config) throws ServletException {
        // GenericServlet must retain the config; otherwise getServletConfig(),
        // getServletContext() and log() are unusable on this servlet.
        super.init(config);

        service = new Service(
            config.getInitParameter(ListenerTagNames.TARGET_SERVICE_CATEGORY_TAG),
            config.getInitParameter(ListenerTagNames.TARGET_SERVICE_NAME_TAG)
        );
        try {
            serviceInvoker = new ServiceInvoker(service);
        } catch (MessageDeliverException e) {
            throw new ServletException("Unable to create ServiceInvoker for Service '" + service + "'.", e);
        }
        endpointAddress = config.getInitParameter(Servlet.ENDPOINT_ADDRESS);
        allowedPorts = parseAllowedPorts(config.getInitParameter(Servlet.ALLOWED_PORTS));

        ConfigTree configTree = toConfigTree(config);

        try {
            messageComposer = MessageComposer.Factory.getInstance(configTree.getAttribute(ListenerTagNames.GATEWAY_COMPOSER_CLASS_TAG, HttpMessageComposer.class.getName()), configTree);
        } catch (ConfigurationException e) {
            throw new ServletException("Failed to create message composer.", e);
        } catch (MessageDeliverException e) {
            throw new ServletException("Failed to create message composer.", e);
        }

        asyncInvoke = configTree.getBooleanAttribute(ASYNC_SERVICE_INVOKE, false);
        if (asyncInvoke) {
            String asyncSCConfig = configTree.getAttribute(ASYNC_STATUS_CODE);
            // Absent configuration keeps the field default (200 OK) instead of failing on parseInt(null).
            if (asyncSCConfig != null) {
                try {
                    asyncStatusCode = Integer.parseInt(asyncSCConfig.trim());
                } catch (NumberFormatException e) {
                    throw new ServletException("Invalid static asynchronous response code configuration '" + asyncSCConfig + "'.", e);
                }
            }

            String payloadPath = configTree.getAttribute(ASYNC_PAYLOAD);
            if (payloadPath != null) {
                try {
                    asyncPayload = readStaticAsyncResponse(payloadPath);
                } catch (ConfigurationException e) {
                    throw new ServletException("Failed to read static asynchronous response payload '" + payloadPath + "'.", e);
                }
                asyncContentType = configTree.getAttribute(ASYNC_PAYLOAD_CONTENT_TYPE);
                asyncCharacterEncoding = configTree.getAttribute(ASYNC_PAYLOAD_CHARACTER_ENCODING);
            }
        } else {
            blockingTimeout = configTree.getLongAttribute("synchronousTimeout", 30000L);
        }

        String exceptionMappingsCSV = configTree.getAttribute(EXCEPTION_MAPPINGS);
        if (exceptionMappingsCSV != null) {
            try {
                exceptionMappings = HttpGatewayMapper.decodeExceptionMappingsCSV(exceptionMappingsCSV);
            } catch (ConfigurationException e) {
                throw new ServletException("Invalid Exception to HTTP Status mapping configuration.", e);
            }
        }
    }

    /**
     * Parses a comma-separated list of port numbers.
     *
     * @param allowedPortsCSV the configured list, or {@code null} when not configured.
     * @return the parsed ports; an empty array when {@code allowedPortsCSV} is {@code null}.
     * @throws ServletException if any entry is not a valid integer.
     */
    private static int[] parseAllowedPorts(String allowedPortsCSV) throws ServletException {
        if (allowedPortsCSV == null) {
            return new int[0];
        }
        StringTokenizer tokens = new StringTokenizer(allowedPortsCSV, ",");
        int[] ports = new int[tokens.countTokens()];
        for (int i = 0; i < ports.length; i++) {
            // Tolerate whitespace around entries, e.g. "8080, 8443".
            String token = tokens.nextToken().trim();
            try {
                ports[i] = Integer.parseInt(token);
            } catch (NumberFormatException e) {
                throw new ServletException("Invalid allowed port '" + token + "' in configuration '" + allowedPortsCSV + "'.", e);
            }
        }
        return ports;
    }

    /**
     * Serves the request if it arrived on an allowed port, otherwise responds with
     * 404 Not Found.
     */
    protected void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        if (isPortAllowed(req.getServerPort())) {
            processServiceRequest(req, resp);
        } else {
            resp.sendError(HttpServletResponse.SC_NOT_FOUND);
        }
    }

    /** Returns true when no port restriction is configured or {@code port} is in the allowed list. */
    private boolean isPortAllowed(int port) {
        if (allowedPorts.length == 0) {
            return true;
        }
        for (int allowedPort : allowedPorts) {
            if (allowedPort == port) {
                return true;
            }
        }
        return false;
    }

    /**
     * Handles a request on an allowed port: serves the contract for {@code ?wsdl} queries,
     * otherwise composes an ESB message, delivers it to the target service and writes the
     * outcome to the response.
     *
     * @throws ServletException if composing, delivering or decomposing the message fails, or
     *                          a fault is raised whose cause has no configured HTTP status mapping.
     */
    private void processServiceRequest(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        // If it's a wsdl request, serve up the contract then short-circuit.
        // regionMatches(ignoreCase) is locale-independent, unlike toLowerCase().
        String query = req.getQueryString();
        if (query != null && query.regionMatches(true, 0, "wsdl", 0, 4)) {
            handleWsdlRequest(req, resp);
            return;
        }

        HttpRequestWrapper wrapper = new HttpRequestWrapper(req, resp, service);

        Message inMessage;
        try {
            inMessage = messageComposer.compose(wrapper);
        } catch (MessageDeliverException e) {
            throw new ServletException("Failed to create message.", e);
        }

        // Add any servlet authentication details to the message if there were none
        // presented in the message payload e.g. on a SOAP message...
        if (!PublicCryptoUtil.INSTANCE.isAuthRequestOnMessage(inMessage)) {
            HttpServletSecUtil.addAuthDetailsToMessage(req, inMessage);
        }

        Message outMessage;
        try {
            // Dispatch the message to the action pipeline...
            if (!asyncInvoke) {
                outMessage = serviceInvoker.deliverSync(inMessage, blockingTimeout);

                // Set the mep as a header on the response...
                resp.setHeader(ASYNC_SERVICE_INVOKE, "false");
            } else {
                serviceInvoker.deliverAsync(inMessage);

                resp.setStatus(asyncStatusCode);

                // Set the mep as a header on the response...
                resp.setHeader(ASYNC_SERVICE_INVOKE, "true");

                if (asyncPayload != null) {
                    resp.setContentLength(asyncPayload.length);
                    resp.setContentType(asyncContentType);
                    if (asyncCharacterEncoding != null) {
                        resp.setCharacterEncoding(asyncCharacterEncoding);
                    }

                    // Only write to the output after setting all headers etc...
                    resp.getOutputStream().write(asyncPayload);
                } else {
                    resp.setContentLength(0);
                }

                return;
            }
        } catch (MessageDeliverException e) {
            throw new ServletException("Failed to deliver message.", e);
        } catch (RegistryException e) {
            throw new ServletException("Failed to deliver message.", e);
        } catch (FaultMessageException e) {
            Throwable cause = e.getCause();
            String exceptionClass = (cause != null ? cause.getClass().getName() : null);
            Integer mappedStatus = (exceptionClass != null && exceptionMappings != null
                    ? exceptionMappings.get(exceptionClass) : null);

            if (mappedStatus == null) {
                throw new ServletException("Failed to deliver message.", e);
            }

            resp.setStatus(mappedStatus.intValue());
            resp.setHeader("Exception", exceptionClass);
            // The content type must be set before getWriter() is obtained, otherwise it may be ignored.
            resp.setContentType("text/plain");
            // NOTE(review): this sends the full fault stack trace to the HTTP client; confirm that is intended.
            e.printStackTrace(resp.getWriter());
            return;
        }

        if (outMessage == null) {
            resp.setContentLength(0);
            resp.setStatus(HttpServletResponse.SC_NO_CONTENT);
            return;
        }

        try {
            messageComposer.decompose(outMessage, wrapper);
        } catch (MessageDeliverException e) {
            throw new ServletException("Failed to decompose response message.", e);
        }
    }

    /** Copies every servlet init parameter into the attributes of a new "config" tree. */
    private ConfigTree toConfigTree(ServletConfig config) {
        ConfigTree configTree = new ConfigTree("config");
        Enumeration<?> configNames = config.getInitParameterNames();

        while (configNames.hasMoreElements()) {
            String name = (String) configNames.nextElement();
            configTree.setAttribute(name, config.getInitParameter(name));
        }

        return configTree;
    }

    /**
     * Reads the static asynchronous response payload from the classpath.
     *
     * @param payloadPath classpath location of the payload file.
     * @return the file contents.
     * @throws ConfigurationException if the resource cannot be found.
     */
    private byte[] readStaticAsyncResponse(String payloadPath) throws ConfigurationException {
        InputStream stream = ClassUtil.getResourceAsStream(payloadPath, HttpGatewayServlet.class);

        if (stream == null) {
            throw new ConfigurationException("Failed to access static HTTP response payload file '" + payloadPath + "' on classpath.");
        }

        try {
            return StreamUtils.readStream(stream);
        } finally {
            try {
                stream.close();
            } catch (IOException e) {
                logger.error("Unexpected Error closing static HTTP response payload file '" + payloadPath + "' ", e);
            }
        }
    }

    /**
     * Writes the service contract (or the resource named by the {@code resource} request
     * parameter) as UTF-8, with every {@code @REQUEST_URL@} token replaced by the request URL.
     * When no contract is available, an empty {@code <definitions/>} document is returned.
     */
    private void handleWsdlRequest(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        String mimeType;
        String data;
        ContractInfo contract = getContract();
        if (contract != null) {
            mimeType = contract.getMimeType();
            String resource = req.getParameter("resource");
            if (resource != null) {
                data = contract.getResource(resource);
            } else {
                data = contract.getData();
            }
            if (data != null) {
                // Literal replacement: replaceAll() would interpret '$' or '\' in the URL
                // as regex group references and throw or corrupt the output.
                data = data.replace(REQUEST_URL_TOKEN, req.getRequestURL().toString());
            } else {
                data = "";
            }
        } else {
            mimeType = "text/xml";
            data = "<definitions/>";
        }
        resp.setCharacterEncoding("UTF-8");
        resp.setContentType(mimeType);
        byte[] bytes = data.getBytes("UTF-8");
        resp.setContentLength(bytes.length);
        OutputStream os = new BufferedOutputStream(resp.getOutputStream());
        os.write(bytes);
        os.flush();
    }

    /**
     * Lazily resolves the contract for this servlet's service via the
     * {@link ContractProviderLifecycleResource}. If no provider is registered, {@code null} is
     * returned and resolution is retried on the next call.
     *
     * @throws ServletException if the contract provider lookup fails.
     */
    private ContractInfo getContract() throws ServletException, IOException {
        // Read the volatile field once into a local to avoid a second volatile read on the fast path.
        ContractInfo result = contractInfo;
        if (result == null) {
            synchronized (this) {
                result = contractInfo;
                if (result == null) {
                    ContractProvider contractProvider;
                    try {
                        contractProvider = ContractProviderLifecycleResource.getContractProvider(service.getCategory(), service.getName());
                    } catch (LifecycleResourceException lre) {
                        throw new ServletException(lre);
                    }
                    if (contractProvider != null) {
                        result = contractProvider.provideContract(service, endpointAddress);
                        contractInfo = result;
                    }
                }
            }
        }
        return result;
    }

}
