import { Color, Vector4 } from "three";

// Pass-through vertex shader: copies the mesh `uv` attribute into the `vUv`
// varying and applies the usual projection * modelView transform.
// NOTE: createGradientShader() edits this text with plain substring replaces
// when its `csm` option is set (it uncomments the `csm_Position` line and
// comments out the `gl_Position` line), so the exact spelling of "// csm_"
// and "gl_Pos" inside the string must be preserved.
const vertexShader = `
varying vec2 vUv;

void main() {
	vUv = uv;
	gl_Position = projectionMatrix * modelViewMatrix * vec4(position, 1.0);
	// csm_Position = position; 
}
`;

// Fallback gradient endpoints (black to white), used as the default `colors`
// of createGradientShader().
const [ black, white ] = [ "#000", "#fff" ].map((hex) => new Color(hex));

/**
 * Formats a number as a GLSL float literal. Integers get an explicit ".0"
 * so the literal is never parsed as an int by the shader compiler.
 */
function toGlslFloat(value) {
	return Number.isInteger(value) ? value.toFixed(1) : String(value);
}

/**
 * Returns one stop position per colour. Falls back to evenly spaced stops in
 * [0, 1] when `stops` is missing/empty, and (with a warning) when it does not
 * contain exactly `count` finite, non-decreasing numbers.
 */
function resolveGradientStops(stops, count) {
	const evenlySpaced = Array.from(
		{ length: count },
		(_, i) => (count > 1 ? i / (count - 1) : 0)
	);
	if (!stops || !stops.length) {
		return evenlySpaced;
	}
	const isUsable = stops.length === count
		&& stops.every((s, i) => Number.isFinite(s) && (i === 0 || s >= stops[ i - 1 ]));
	if (!isUsable) {
		console.warn(`Invalid gradient stops: expected ${count} finite, ascending numbers, defaulting to even spacing`);
		return evenlySpaced;
	}
	return stops;
}

/**
 * Builds the GLSL expression for the blend factor of the segment [from, to].
 * A zero-width segment (duplicate stops) becomes a hard edge via step(),
 * which avoids emitting a division by zero into the shader.
 */
function gradientSegmentFactor(from, to) {
	return to > from
		? `clamp((mixValue - ${toGlslFloat(from)}) / ${toGlslFloat(to - from)}, 0.0, 1.0)`
		: `step(${toGlslFloat(from)}, mixValue)`;
}

/**
 * Creates a shader definition (uniforms + vertex/fragment source) that paints
 * a multi-stop gradient across the mesh UVs.
 *
 * @param {object} [options]
 * @param {Array} [options.colors] - Gradient colours. `Color` instances are
 *   converted to `Vector4(r, g, b, opacity)`; any other value is used as the
 *   uniform value unchanged (the shader declares every uniform as `vec4`).
 *   An empty array falls back to black/white with a warning; a single colour
 *   yields a flat fill.
 * @param {number} [options.opacity=1] - Alpha applied to `Color` entries only.
 * @param {number[]} [options.stops] - One ascending position per colour.
 *   Defaults to even spacing over [0, 1]. Duplicate values give a hard edge.
 * @param {"x"|"y"|"xy"|"yx"|"radial"} [options.direction="x"] - Gradient axis;
 *   unknown values warn and fall back to "x".
 * @param {number[]} [options.center=[0.5, 0.5]] - UV centre for "radial".
 * @param {boolean} [options.csm=false] - When true, writes to
 *   `csm_DiffuseColor` / `csm_Position` instead of `gl_FragColor` /
 *   `gl_Position`.
 * @returns {{uniforms: object, vertexShader: string, fragmentShader: string, transparent: boolean}}
 */
export function createGradientShader({
	colors = [ black, white ],
	opacity = 1,
	stops = undefined,
	direction = "x",
	center: [ cx = 0.5, cy = 0.5 ] = [ 0.5, 0.5 ],
	csm = false
} = {}) {
	let palette = colors;
	if (!palette.length) {
		console.warn("No gradient colors given, defaulting to black and white");
		palette = [ black, white ];
	}
	const steps = resolveGradientStops(stops, palette.length);

	const uniforms = {};
	palette.forEach((c, i) => {
		uniforms[ `color${i}` ] = {
			value: c instanceof Color ? new Vector4(c.r, c.g, c.b, opacity) : c
		};
	});

	let mixString;
	switch (direction) {
		case "x":
		case "y":
			mixString = `vUv.${direction}`;
			break;
		case "xy":
			mixString = "(vUv.x + vUv.y) / 2.0";
			break;
		case "yx":
			mixString = "(vUv.x + 1.0 - vUv.y) / 2.0";
			break;
		case "radial":
			mixString = `clamp(distance(vUv, vec2(${toGlslFloat(cx)}, ${toGlslFloat(cy)})), 0.0, 1.0)`;
			break;
		default:
			console.warn(`Invalid gradient direction: ${direction}, defaulting to "x"`);
			mixString = "vUv.x";
			break;
	}

	const colorOutput = csm ? "csm_DiffuseColor" : "gl_FragColor";
	const lastSegment = palette.length - 2;
	const mainBody = [];

	if (lastSegment < 0) {
		// A single colour has nothing to blend with: flat fill.
		mainBody.push(`${colorOutput} = color0;`);
	} else if (lastSegment === 0 && steps[ 0 ] === 0 && steps[ 1 ] === 1) {
		// Plain two-colour gradient over [0, 1]: kept as an unclamped mix on
		// the raw coordinate so existing output is unchanged.
		mainBody.push(`${colorOutput} = mix(color0, color1, ${mixString});`);
	} else {
		mainBody.push(`float mixValue = ${mixString};`);
		for (let seg = 0; seg <= lastSegment; seg++) {
			const factor = gradientSegmentFactor(steps[ seg ], steps[ seg + 1 ]);
			const assign = `${colorOutput} = mix(color${seg}, color${seg + 1}, ${factor});`;
			if (lastSegment === 0) {
				// Only one segment: no branching needed, clamping holds the ends.
				mainBody.push(assign);
			} else if (seg === lastSegment) {
				mainBody.push(`else { ${assign} }`);
			} else {
				const keyword = seg === 0 ? "if" : "else if";
				mainBody.push(`${keyword} (mixValue < ${toGlslFloat(steps[ seg + 1 ])}) { ${assign} }`);
			}
		}
	}

	const fragmentShader = [
		...Object.keys(uniforms).map((u) => `uniform vec4 ${u};`),
		"varying vec2 vUv;",
		"void main() {",
		...mainBody.map((line) => `\t${line}`),
		"}"
	].join("\n");

	return {
		uniforms,
		// For csm, swap which position assignment is active in the template.
		vertexShader: csm
			? vertexShader.replace("// csm_", "csm_").replace("gl_Pos", "// gl_Pos")
			: vertexShader,
		fragmentShader,
		transparent: true
	};
}