
import difflib
import sys
import unittest

import glitch
from glitch.camera import Camera
from glitch.context import Context

class Add(object):
    """Symbolic ``a + b`` node whose repr reads like the source expression.

    Used so that fake GL constants combined with ``+`` log readably
    (e.g. ``GL_TEXTURE0 + 0``).
    """

    def __init__(self, a, b):
        self.a, self.b = a, b

    def __repr__(self):
        return '{!r} + {!r}'.format(self.a, self.b)

    def __add__(self, other):
        # Further additions nest to the left: (a + b) + c.
        node = Add(self, other)
        return node

class Or(object):
    """Symbolic ``a | b`` node whose repr reads like the source expression.

    Used so that fake GL bit flags combined with ``|`` log readably
    (e.g. ``GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT``).
    """

    def __init__(self, a, b):
        self.a, self.b = a, b

    def __repr__(self):
        return '{!r} | {!r}'.format(self.a, self.b)

    def __or__(self, other):
        # Further ors nest to the left: (a | b) | c.
        node = Or(self, other)
        return node

class FakeAttr(object):
    """Stand-in for an attribute that a fake module does not really define.

    Calling it appends ``name(arg1, arg2, ...)`` (arguments rendered with
    ``repr``) to the shared log.  Its repr is the bare name, so fake GL
    constants print as e.g. ``GL_MODELVIEW``.  ``+`` and ``|`` build
    symbolic Add/Or nodes so combined constants log readably.

    Two FakeAttrs are equal when they come from the same owner object and
    have the same name; comparison with any other type is simply unequal.
    """

    def __init__(self, obj, log, name):
        self.obj = obj    # owner the attribute was looked up on
        self.log = log    # shared list that call records are appended to
        self.name = name  # attribute name; doubles as the repr

    def __call__(self, *args):
        self.log.append(
            '%s(%s)' % (self.name, ', '.join(map(repr, args))))

    def __repr__(self):
        return self.name

    def __add__(self, other):
        return Add(self, other)

    def __or__(self, other):
        return Or(self, other)

    def _key(self):
        # Identity used for equality, hashing and ordering.
        return (self.obj, self.name)

    def __eq__(self, other):
        # Non-FakeAttr operands used to crash on ``other.obj``; defer
        # instead so e.g. ``attr == None`` is just False.
        if not isinstance(other, FakeAttr):
            return NotImplemented
        return self._key() == other._key()

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        # Keep hashing consistent with __eq__ so equal attrs collide.
        return hash(self._key())

    def __cmp__(self, other):
        # Python 2 ordering hook (ignored by Python 3).  Avoids the
        # ``cmp`` builtin so the body is valid on both versions.
        if not isinstance(other, FakeAttr):
            return NotImplemented
        mine, theirs = self._key(), other._key()
        return (mine > theirs) - (mine < theirs)

class WrappedMethod(FakeAttr):
    """FakeAttr that logs each call and then forwards to a real callable.

    The log entry is written before the real implementation runs, and the
    implementation's return value is passed back to the caller.
    """

    def __init__(self, obj, log, name, impl):
        super(WrappedMethod, self).__init__(obj, log, name)
        self.impl = impl  # genuine callable that does the work

    def __call__(self, *args):
        super(WrappedMethod, self).__call__(*args)
        result = self.impl(*args)
        return result

class FakeModule(object):
    """Logging stand-in for a library module such as OpenGL.GL.

    Every attribute lookup on an instance succeeds:

    * an attribute that really exists (instance or class) is returned
      unchanged, unless it is callable, in which case it is wrapped in a
      WrappedMethod so the call is logged before the real code runs;
    * any other name yields a FakeAttr, which logs when called and reprs
      as its own name (convenient for GL constants).
    """

    def __init__(self, log):
        self.log = log  # shared list receiving all call records

    def __getattribute__(self, name):
        # Go through object.__getattribute__ to avoid recursing back into
        # this method.
        d = object.__getattribute__(self, '__dict__')
        log = d['log']

        try:
            v = object.__getattribute__(self, name)
        except AttributeError:
            # Only genuinely missing names become fakes; an attribute whose
            # value is None is returned as None (previously it was
            # mistaken for "missing").
            return FakeAttr(self, log, name)

        if callable(v):
            return WrappedMethod(self, log, name, v)
        return v

class FakeGL(FakeModule):
    """Fake OpenGL.GL module that hands out sequential texture names from 1."""

    def __init__(self, *args):
        super(FakeGL, self).__init__(*args)
        self.next_texture_id = 1  # value the next glGenTextures returns

    def glGenTextures(self, length):
        # ``length`` is accepted for signature compatibility but ignored:
        # one id is returned per call.
        texture_id = self.next_texture_id
        self.next_texture_id = texture_id + 1
        return texture_id

class FakeProgram(object):
    """Context manager posing as a compiled shader program.

    Entering and leaving a ``with`` block append ``program <id>: enter``
    and ``program <id>: exit`` to the log.  ``__enter__`` returns None and
    ``__exit__`` never suppresses exceptions.
    """

    def __init__(self, log, id):
        self.log, self.id = log, id

    def _record(self, event):
        self.log.append('program %d: %s' % (self.id, event))

    def __enter__(self):
        self._record('enter')

    def __exit__(self, type, exc, tb):
        self._record('exit')

class FakeShaders(FakeModule):
    """Fake OpenGL.GL.shaders module.

    Compiled shaders are the strings ``shader1``, ``shader2``, ...; compiled
    programs are FakeProgram objects numbered 1, 2, ... that share the log.
    """

    def __init__(self, *args):
        super(FakeShaders, self).__init__(*args)
        self.next_program_id = 1
        self.next_shader_id = 1

    def compileShader(self, code, type):
        shader_id = self.next_shader_id
        self.next_shader_id = shader_id + 1
        return 'shader%d' % shader_id

    def compileProgram(self, vshader, fshader):
        program_id = self.next_program_id
        self.next_program_id = program_id + 1
        return FakeProgram(self.log, program_id)

class LoggingNode(glitch.Node):
    """glitch.Node that records '<name>: begin/draw/end' events in a log."""

    def __init__(self, name, log, **kw):
        glitch.Node.__init__(self, **kw)
        self.log = log
        self.name = name

    def _note(self, event):
        self.log.append('%s: %s' % (self.name, event))

    def render(self, ctx):
        # Bracket the base-class render with begin/end markers; the base
        # presumably calls draw() and renders children (see test_recurse).
        self._note('begin')
        glitch.Node.render(self, ctx)
        self._note('end')

    def draw(self, ctx):
        self._note('draw')

class GlitchTest(unittest.TestCase):
    """Render glitch nodes against fake GL/shader modules and check the
    exact sequence of calls they make."""

    def assertStringsEqual(self, expected, other):
        """Fail with a readable unified diff unless the two lists of log
        lines are equal."""
        if expected != other:
            diff = list(difflib.unified_diff(
                expected, other, 'expected', 'actual', lineterm=''))
            # An empty diff means the sequences differ only in type (e.g.
            # tuple vs list); fall back to showing both values.
            self.fail('\n' + '\n'.join(diff) if diff
                      else '%r != %r' % (expected, other))

    def setUp(self):
        self.log = []
        self.ctx = Context()
        self.fake_gl = FakeGL(self.log)
        self.fake_shaders = FakeShaders(self.log)
        self.undo = []

        # Dependency injection, Python style: replace the module-level
        # ``gl`` / ``shaders`` names of every loaded glitch.* submodule
        # with fakes, remembering the originals for tearDown.  Iterate a
        # snapshot so sys.modules changing underneath us cannot break the
        # loop.
        for (name, module) in list(sys.modules.items()):
            if name.startswith('glitch.'):
                if hasattr(module, 'gl'):
                    self.undo.append((module, 'gl', module.gl))
                    module.gl = self.fake_gl

                if hasattr(module, 'shaders'):
                    self.undo.append((module, 'shaders', module.shaders))
                    module.shaders = self.fake_shaders

    def tearDown(self):
        # Undo patches last-in, first-out.
        for (module, name, value) in reversed(self.undo):
            setattr(module, name, value)

    def test_recurse(self):
        node = LoggingNode('a', self.log, children=[
            LoggingNode('b', self.log),
            LoggingNode('c', self.log)])
        node.render(self.ctx)
        self.assertStringsEqual([
            'a: begin',
            'a: draw',
            'b: begin', 'b: draw', 'b: end',
            'c: begin', 'c: draw', 'c: end',
            'a: end'
            ], self.log)

    def test_translate(self):
        node = glitch.Translate(x=1,
            children=[LoggingNode('a', self.log)])
        node.render(self.ctx)
        self.assertStringsEqual([
            'glMatrixMode(GL_MODELVIEW)',
            'glPushMatrix()',
            'glTranslate(1, 0, 0)',
            'a: begin',
            'a: draw',
            'a: end',
            'glMatrixMode(GL_MODELVIEW)',
            'glPopMatrix()'
            ], self.log)

    def test_scale(self):
        node = glitch.Scale(x=2,
            children=[LoggingNode('a', self.log)])
        node.render(self.ctx)
        self.assertStringsEqual([
            'glMatrixMode(GL_MODELVIEW)',
            'glPushMatrix()',
            'glScale(2, 1, 1)',
            'a: begin',
            'a: draw',
            'a: end',
            'glMatrixMode(GL_MODELVIEW)',
            'glPopMatrix()'
            ], self.log)

    def test_rotate(self):
        node = glitch.Rotate(angle=90, x=1,
            children=[LoggingNode('a', self.log)])
        node.render(self.ctx)
        self.assertStringsEqual([
            'glMatrixMode(GL_MODELVIEW)',
            'glPushMatrix()',
            'glRotate(90, 1, 0, 0)',
            'a: begin',
            'a: draw',
            'a: end',
            'glMatrixMode(GL_MODELVIEW)',
            'glPopMatrix()'
            ], self.log)

    def test_color(self):
        node = glitch.Color(r=1, g=2, b=3, a=4,
            children=[LoggingNode('a', self.log)])
        node.render(self.ctx)
        self.assertStringsEqual([
            'glPushAttrib(GL_CURRENT_BIT)',
            'glColor(1, 2, 3, 4)',
            'a: begin',
            'a: draw',
            'a: end',
            'glPopAttrib()'
            ], self.log)

    def test_camera(self):
        camera = Camera(
            children=[LoggingNode('a', self.log)])
        camera.context['w'] = 10
        camera.context['h'] = 10
        camera.render(None)
        self.assertStringsEqual([
            'glClearColor(0, 0, 0, 0)',
            'glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT)',
            'glColor(1, 1, 1)',
            'glMaterialfv(GL_FRONT_AND_BACK, GL_AMBIENT_AND_DIFFUSE, '
                '(1, 1, 1, 1))',
            'glLightModelfv(GL_LIGHT_MODEL_AMBIENT, [0, 0, 0, 1])',
            'glPolygonMode(GL_FRONT_AND_BACK, GL_FILL)',
            'glViewport(0, 0, 10, 10)',
            'glMatrixMode(GL_MODELVIEW)',
            'glLoadIdentity()',
            'glMatrixMode(GL_PROJECTION)',
            'glLoadIdentity()',
            'glEnable(GL_DEPTH_TEST)',
            'glMatrixMode(GL_MODELVIEW)',
            'a: begin',
            'a: draw',
            'a: end',
            ], self.log)

    def test_camera_with_parent(self):
        camera = Camera(
            children=[LoggingNode('a', self.log)])
        camera.context['w'] = 20
        camera.context['h'] = 20
        self.ctx['w'] = 10
        self.ctx['h'] = 10
        camera.render(self.ctx)
        self.assertStringsEqual([
            # Save.
            'glMatrixMode(GL_PROJECTION)',
            'glPushMatrix()',
            'glMatrixMode(GL_MODELVIEW)',
            'glPushMatrix()',
            'glPushAttrib(GL_VIEWPORT_BIT | GL_TRANSFORM_BIT | '
                'GL_ENABLE_BIT | GL_LIGHTING_BIT | GL_CURRENT_BIT)',

            # Draw.
            'glClearColor(0, 0, 0, 0)',
            'glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT)',
            'glColor(1, 1, 1)',
            'glMaterialfv(GL_FRONT_AND_BACK, GL_AMBIENT_AND_DIFFUSE, '
                '(1, 1, 1, 1))',
            'glLightModelfv(GL_LIGHT_MODEL_AMBIENT, [0, 0, 0, 1])',
            'glPolygonMode(GL_FRONT_AND_BACK, GL_FILL)',
            'glViewport(0, 0, 20, 20)',
            'glMatrixMode(GL_MODELVIEW)',
            'glLoadIdentity()',
            'glMatrixMode(GL_PROJECTION)',
            'glLoadIdentity()',
            'glEnable(GL_DEPTH_TEST)',
            'glMatrixMode(GL_MODELVIEW)',
            'a: begin',
            'a: draw',
            'a: end',

            # Restore.
            'glPopAttrib()',
            'glMatrixMode(GL_PROJECTION)',
            'glPopMatrix()',
            'glMatrixMode(GL_MODELVIEW)',
            'glPopMatrix()'
            ], self.log)

    def test_texture(self):
        texture = glitch.Texture(20, 20, 'abcdef')
        node = glitch.ApplyTexture(texture,
            children=[LoggingNode('a', self.log)])
        node.render(self.ctx)
        node.render(self.ctx)
        self.assertStringsEqual([
            # First render.
            'glActiveTexture(GL_TEXTURE0 + 0)',
            'glPushAttrib(GL_ENABLE_BIT | GL_TEXTURE_BIT)',
            'glEnable(GL_TEXTURE_2D)',
            'glGenTextures(1)',
            'glBindTexture(GL_TEXTURE_2D, 1)',
            'glTexParameterf('
                'GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE)',
            'glTexParameterf('
                'GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE)',
            'glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_LINEAR)',
            'glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR)',
            "glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA, 20, 20, 0, GL_RGBA, "
                "GL_UNSIGNED_BYTE, 'abcdef')",
            'glBindTexture(GL_TEXTURE_2D, 1)',
            'a: begin',
            'a: draw',
            'a: end',
            'glActiveTexture(GL_TEXTURE0 + 0)',
            'glPopAttrib()',

            # Second render (no parameters/upload).
            'glActiveTexture(GL_TEXTURE0 + 0)',
            'glPushAttrib(GL_ENABLE_BIT | GL_TEXTURE_BIT)',
            'glEnable(GL_TEXTURE_2D)',
            'glBindTexture(GL_TEXTURE_2D, 1)',
            'a: begin',
            'a: draw',
            'a: end',
            'glActiveTexture(GL_TEXTURE0 + 0)',
            'glPopAttrib()'
            ], self.log)

        # Deleting the texture must drop it from the context and release
        # the GL name.  Use unittest assertions, not bare ``assert``, so
        # the checks still run under ``python -O``.
        self.log[:] = []
        self.assertIn(texture, self.ctx['texture'])
        texture.context_delete(self.ctx)
        self.assertNotIn(texture, self.ctx['texture'])
        self.assertEqual(['glDeleteTextures(1)'], self.log)

    def test_shader(self):
        node = glitch.Shader(vertex='vertex', fragment='fragment',
            children=[LoggingNode('a', self.log)])
        node.render(self.ctx)
        node.render(self.ctx)
        self.assertStringsEqual([
            # First render.
            "compileShader('vertex', GL_VERTEX_SHADER)",
            "compileShader('fragment', GL_FRAGMENT_SHADER)",
            "compileProgram('shader1', 'shader2')",
            'program 1: enter',
            'a: begin',
            'a: draw',
            'a: end',
            'program 1: exit',

            # Second render.
            'program 1: enter',
            'a: begin',
            'a: draw',
            'a: end',
            'program 1: exit',
            ], self.log)

