package org.sunflow.core.shader;

import org.sunflow.SunflowAPI;
import org.sunflow.core.ParameterList;
import org.sunflow.core.Ray;
import org.sunflow.core.ShadingState;
import org.sunflow.core.Texture;
import org.sunflow.image.Color;
import org.sunflow.math.Vector3;

/**
 * Shiny diffuse shader whose diffuse color is read from an image texture.
 * <p>
 * EP : Added transparency management. When the texture reports transparency,
 * each shaded point blends the lit surface color (weighted by the texture
 * opacity) with the color seen through the surface (weighted by the remaining
 * transparency), plus a Fresnel-weighted reflection.
 */
public class TexturedShinyDiffuseShader extends ShinyDiffuseShader {
    /** Diffuse color / opacity source; {@code null} until {@link #update} loads one. */
    private Texture tex;

    public TexturedShinyDiffuseShader() {
        tex = null;
    }

    /**
     * Loads the texture named by the {@code "texture"} parameter (when present)
     * and then updates the inherited shiny diffuse parameters.
     *
     * @param pl  parameters for this shader
     * @param api API instance used to resolve the file name and to access its texture cache
     * @return {@code true} if a texture is available and the parent update succeeded;
     *         the parent update is not run when no texture is set
     */
    @Override
    public boolean update(ParameterList pl, SunflowAPI api) {
        String filename = pl.getString("texture", null);
        if (filename != null) {
            // EP : Made texture cache local to a SunFlow API instance
            tex = api.getTextureCache().getTexture(api.resolveTextureFilename(filename), false);
        }
        return tex != null && super.update(pl, api);
    }

    /**
     * Returns the texture color at the shading point's UV coordinates.
     */
    @Override
    public Color getDiffuse(ShadingState state) {
        return tex.getPixel(state.getUV().x, state.getUV().y);
    }

    /**
     * Computes the radiance at the shading point.
     * <p>
     * Fully opaque pixels are shaded by the parent shader. Otherwise the result is
     * {@code opacity * litDiffuse + (1 - opacity) * refraction}, plus (when specular
     * contributions are enabled and the pixel is not fully transparent) a reflection
     * term whose weight is {@code (1 - opacity) * shininess} blended toward white by
     * a Schlick-style {@code (1 - cos)^5} factor.
     */
    // EP : Added transparency management
    @Override
    public Color getRadiance(ShadingState state) {
        Color opacity;
        if (isOpaque() || (opacity = getOpacity(state)).isWhite()) {
            // Pixel is fully opaque
            return super.getRadiance(state);
        }
        state.faceforward();
        // direct lighting
        state.initLightSamples();
        state.initCausticSamples();

        // Color seen through the surface, weighted by the transparent fraction.
        Color result = Color.sub(Color.WHITE, opacity);
        Color refraction = state.traceRefraction(new Ray(state.getPoint(), state.getRay().getDirection()), 0);
        result.mul(refraction);

        if (opacity.isBlack()) {
            // No surface color and no reflection when fully transparent
            return result;
        }

        // Lit surface color, weighted by the opaque fraction. Without this term the
        // light samples initialized above were never used and partially transparent
        // pixels (e.g. antialiased alpha edges) lost their surface color entirely.
        result.add(state.diffuse(getDiffuse(state)).mul(opacity));

        if (!state.includeSpecular()) {
            return result;
        }

        // Mirror the incoming ray direction about the (face-forwarded) normal.
        Vector3 rayDir = state.getRay().getDirection();
        Vector3 normal = state.getNormal();
        float cos = state.getCosND();
        float dn = 2 * cos;
        Vector3 refDir = new Vector3();
        refDir.x = (dn * normal.x) + rayDir.x;
        refDir.y = (dn * normal.y) + rayDir.y;
        refDir.z = (dn * normal.z) + rayDir.z;
        Ray refRay = new Ray(state.getPoint(), refDir);

        // compute Fresnel term: (1 - cos)^5
        float oneMinusCos = 1 - cos;
        float oneMinusCos2 = oneMinusCos * oneMinusCos;
        float fresnel = oneMinusCos2 * oneMinusCos2 * oneMinusCos;

        // ret = r + (white - r) * fresnel, with r = (1 - opacity) * shininess
        Color ret = Color.white();
        Color r = Color.sub(Color.WHITE, opacity).mul(getShininess());
        ret.sub(r);
        ret.mul(fresnel);
        ret.add(r);
        return result.add(ret.mul(state.traceReflection(refRay, 0)));
    }

    /**
     * Returns whether this shader is fully opaque, i.e. its texture reports no
     * transparency. A shader without a loaded texture is treated as opaque rather
     * than throwing.
     */
    @Override
    public boolean isOpaque() {
        return tex == null || !tex.isTransparent();
    }

    /**
     * Returns the texture opacity at the shading point's UV coordinates, or white
     * (fully opaque) when no texture is loaded.
     */
    @Override
    public Color getOpacity(ShadingState state) {
        if (tex == null) {
            return Color.white();
        }
        return tex.getOpacity(state.getUV().x, state.getUV().y);
    }
    // EP : End of modification
}