1 module unit_threaded.mock;
2 
3 import std.traits;
4 import std.traits: allSameType, allSatisfy;
5 import unit_threaded.should: UnitTestException;
6 
/// Forwards its argument unchanged so that the result of an expression
/// such as `__traits(getMember, ...)` can be bound to an `alias`.
template Identity(alias T) {
    alias Identity = T;
}

/// True when `__traits(getMember, T, member)` does not compile from this module.
private template isPrivate(T, string member) {
    enum isPrivate = !__traits(compiles, __traits(getMember, T, member));
}
9 
10 
/**
   Builds, at compile time, the source code of the overrides that turn a
   class deriving from `T` into a mock. The result is meant to be mixed
   into such a class (see `Mock.MockAbstract`).

   For every accessible virtual method overload that is neither `const`
   nor `immutable`, the generated code contains:
   - the aliases `<name>_<i>_parameters` and `<name>_<i>_returnType`;
   - for non-void overloads, a `<name>_<i>_returnValues` array;
   - an `override` that appends to `calledFuncs` / `calledValues` and,
     for non-void overloads, returns the next queued value or `.init`.

   Returns: the generated code, or `null` when not evaluated in CTFE.
 */
string implMixinStr(T)() {
    import std.array: join;
    import std.format : format;
    import std.traits: functionAttributes, FunctionAttribute, Parameters, ReturnType, arity;
    import std.conv: text;

    if(!__ctfe) return null;

    string[] lines;

    // Source text that, once mixed in, names the i-th overload of memberName.
    string getOverload(in string memberName, in int i) {
        return `Identity!(__traits(getOverloads, T, "%s")[%s])`
            .format(memberName, i);
    }

    foreach(memberName; __traits(allMembers, T)) {

        static if(!isPrivate!(T, memberName)) {

            alias member = Identity!(__traits(getMember, T, memberName));

            static if(__traits(isVirtualMethod, member)) {
                foreach(i, overload; __traits(getOverloads, T, memberName)) {

                    // All traits are queried on `overload`, not on `member`:
                    // `member` names the whole overload set, so its traits
                    // cannot describe each overload individually.
                    enum attrs = functionAttributes!overload;
                    enum isNothrow = (attrs & FunctionAttribute.nothrow_) != 0;

                    // const and immutable methods are skipped: the override
                    // has to mutate the mock's bookkeeping arrays.
                    static if(!(attrs & FunctionAttribute.const_) &&
                              !(attrs & FunctionAttribute.immutable_)) {

                        enum overloadName = text(memberName, "_", i);

                        enum overloadString = getOverload(memberName, i);
                        lines ~= "private alias %s_parameters = Parameters!(%s);".format(overloadName, overloadString);
                        lines ~= "private alias %s_returnType = ReturnType!(%s);".format(overloadName, overloadString);

                        // statements inside the generated try block are indented further
                        static if(isNothrow)
                            enum tryIndent = "    ";
                        else
                            enum tryIndent = "";

                        static if(is(ReturnType!overload == void))
                            enum returnDefault = "";
                        else {
                            enum varName = overloadName ~ `_returnValues`;
                            lines ~= `%s_returnType[] %s;`.format(overloadName, varName);
                            lines ~= "";
                            // pop the next queued return value, falling back to .init
                            enum returnDefault = [`    if(` ~ varName ~ `.length > 0) {`,
                                                  `        auto ret = ` ~ varName ~ `[0];`,
                                                  `        ` ~ varName ~ ` = ` ~ varName ~ `[1..$];`,
                                                  `        return ret;`,
                                                  `    } else`,
                                                  `        return %s_returnType.init;`.format(overloadName)];
                        }

                        lines ~= `override ` ~ overloadName ~ "_returnType " ~ memberName ~
                            typeAndArgsParens!(Parameters!overload)(overloadName) ~ " " ~
                            functionAttributesString!overload ~ ` {`;

                        // the bookkeeping may throw, so a nothrow override
                        // wraps it in try/catch
                        static if(isNothrow)
                            lines ~= "try {";

                        lines ~= tryIndent ~ `    calledFuncs ~= "` ~ memberName ~ `";`;
                        lines ~= tryIndent ~ `    calledValues ~= tuple` ~ argNamesParens(arity!overload) ~ `.to!string;`;

                        static if(isNothrow)
                            lines ~= "    } catch(Exception) {}";

                        lines ~= returnDefault;

                        lines ~= `}`;
                        lines ~= "";
                    }
                }
            }
        }
    }

    return lines.join("\n");
}
89 
/// The parenthesised argument list "(arg0, arg1, ...)" with N entries.
/// CTFE only: returns `null` when called at run time.
private string argNamesParens(int N) @safe pure {
    return __ctfe ? "(" ~ argNames(N) ~ ")" : null;
}
94 
/// The comma-separated names "arg0, arg1, ..." for N arguments.
/// CTFE only: returns `null` when called at run time.
private string argNames(int N) @safe pure {
    import std.conv;

    if(!__ctfe) return null;

    string names;
    foreach(idx; 0 .. N) {
        if(idx > 0) names ~= ", ";
        names ~= "arg" ~ idx.to!string;
    }
    return names;
}
103 
/// Builds a parameter list of the form
/// "(<prefix>_parameters[0] arg0, <prefix>_parameters[1] arg1, ...)" with
/// one entry per element of `T`. CTFE only: returns `null` at run time.
private string typeAndArgsParens(T...)(string prefix) {
    import std.conv;

    if(!__ctfe) return null;

    string decls;
    foreach(idx, Param; T) {
        if(idx > 0) decls ~= ", ";
        decls ~= text(prefix, "_parameters[", idx, "] arg", idx);
    }
    return "(" ~ decls ~ ")";
}
117 
/// Space-separated source spelling of the attributes of `F` that a mock
/// override repeats, in the fixed order: pure, nothrow, @trusted, @safe,
/// @nogc, @system, shared. CTFE only: returns `null` at run time.
private string functionAttributesString(alias F)() {
    import std.traits: functionAttributes, FunctionAttribute;

    if(!__ctfe) return null;

    enum mask = functionAttributes!F;

    // const and immutable can't be done since the mock needs
    // to alter state, so neither appears in the table
    const FunctionAttribute[] flags = [
        FunctionAttribute.pure_,
        FunctionAttribute.nothrow_,
        FunctionAttribute.trusted,
        FunctionAttribute.safe,
        FunctionAttribute.nogc,
        FunctionAttribute.system,
        FunctionAttribute.shared_,
    ];
    const string[] spellings = [
        "pure", "nothrow", "@trusted", "@safe", "@nogc", "@system", "shared",
    ];

    string result;
    foreach(idx, flag; flags) {
        if(!(mask & flag)) continue;
        if(result.length > 0) result ~= " ";
        result ~= spellings[idx];
    }
    return result;
}
142 
/**
   State and checks shared by all mock flavours in this module: the calls
   that were expected, the calls that were recorded, and the comparison of
   the two.
 */
mixin template MockImplCommon() {
    bool _verified;
    string[] expectedFuncs;
    string[] calledFuncs;
    string[] expectedValues;
    string[] calledValues;

    /// Registers an expected call to `funcName`. When no values are given
    /// an empty string is stored, which `verify` treats as "any arguments".
    void expect(string funcName, V...)(auto ref V values) @safe pure {
        import std.conv: to;
        import std.typecons: tuple;

        expectedFuncs ~= funcName;
        static if(V.length == 0)
            expectedValues ~= "";
        else
            expectedValues ~= tuple(values).to!string;
    }

    /// Registers the expectation and verifies straight away; the verified
    /// flag is then cleared so that further checks remain possible.
    void expectCalled(string func, string file = __FILE__, size_t line = __LINE__, V...)(auto ref V values) {
        expect!func(values);
        verify(file, line);
        _verified = false;
    }

    /// Compares, in order, every expected call against the recorded ones and
    /// reports the first discrepancy. Fails if the mock was already verified.
    void verify(string file = __FILE__, size_t line = __LINE__) @safe pure {
        import std.range: repeat, take, join;
        import std.conv: to;
        import unit_threaded.should: fail, UnitTestException;

        if(_verified)
            fail("Mock already _verified", file, line);

        _verified = true;

        foreach(i, expectedFunc; expectedFuncs) {

            if(i >= calledFuncs.length)
                fail("Expected nth " ~ i.to!string ~ " call to " ~ expectedFunc ~ " did not happen", file, line);

            if(calledFuncs[i] != expectedFunc)
                fail("Expected nth " ~ i.to!string ~ " call to " ~ expectedFunc ~ " but got " ~ calledFuncs[i] ~
                     " instead",
                     file, line);

            // an empty expectation means the arguments are not checked
            immutable wanted = expectedValues[i];
            if(wanted != calledValues[i] && wanted != "") {
                // pad the second line so that it lines up under the first
                const padding = " ".repeat.take(expectedFunc.length + 4).join;
                throw new UnitTestException([expectedFunc ~ " was called with unexpected " ~ calledValues[i],
                                             padding ~ "instead of the expected " ~ wanted],
                                            file, line);
            }
        }
    }
}
195 
/// True when `T` is a value whose type is exactly `string`.
private template isString(alias T) {
    enum isString = is(typeof(T) == string);
}
197 
/**
   Mock of the class or interface `T`.

   The struct owns an instance of a generated subclass of `T`
   (`MockAbstract`) and forwards to it through `alias this`, so a `Mock!T`
   can be passed wherever a `T` is expected. Unless `verify` was already
   called, the expectations are checked when the struct is destroyed.
 */
struct Mock(T) {

    MockAbstract _impl;
    alias _impl this;

    /// Generated implementation: overrides from `implMixinStr` plus the
    /// shared expectation bookkeeping.
    class MockAbstract: T {
        import std.conv: to;
        import std.traits: Parameters, ReturnType;
        import std.typecons: tuple;

        //pragma(msg, "\nimplMixinStr for ", T, "\n\n", implMixinStr!T, "\n\n");
        mixin(implMixinStr!T);
        mixin MockImplCommon;
    }

    this(int/* force constructor*/) {
        _impl = new MockAbstract;
    }

    ~this() pure @safe {
        // A default-initialised Mock (e.g. Mock!T.init) has no implementation
        // object: there is nothing to verify, and going through `alias this`
        // would dereference null.
        if(_impl is null) return;
        if(!_verified) verify;
    }

    /// Queues `values` as the successive return values of the first
    /// overload (index 0) of `funcName`.
    void returnValue(string funcName, V...)(V values) {
        assertFunctionIsVirtual!funcName;
        return returnValue!(0, funcName)(values);
    }

    /**
       This version takes overloads into account. i is the overload
       index. e.g.:
       ---------
       interface Interface { void foo(int); void foo(string); }
       auto m = mock!Interface;
       m.returnValue!(0, "foo"); // int overload
       m.returnValue!(1, "foo"); // string overload
       ---------
     */
    void returnValue(int i, string funcName, V...)(V values) {
        assertFunctionIsVirtual!funcName;
        import std.conv: text;
        // name of the array generated by implMixinStr for this overload
        enum varName = funcName ~ text(`_`, i, `_returnValues`);
        foreach(v; values)
            mixin(varName ~ ` ~=  v;`);
    }

    /// Compile-time check that `funcName` is a virtual method of `T`.
    private static void assertFunctionIsVirtual(string funcName)() {
        alias member = Identity!(__traits(getMember, T, funcName));

        static assert(__traits(isVirtualMethod, member),
                      "Cannot use returnValue on '" ~ funcName ~ "'");
    }
}
251 
/// One "import <module>;" line (each terminated by a newline) for `module_`
/// followed by every entry of `Modules`. CTFE only: `null` at run time.
private string importsString(string module_, string[] Modules...) {
    if(!__ctfe) return null;

    string code;
    foreach(name; [module_] ~ Modules)
        code ~= `import ` ~ name ~ ";\n";
    return code;
}
261 
/// Convenience factory: creates a `Mock!T` through its `int` constructor.
Mock!T mock(T)() {
    return typeof(return)(0);
}
265 
266 
// An expectation without argument values only checks that foo was called;
// the check itself runs when the mock is destroyed at the end of the test.
@("mock interface positive test no params")
@safe pure unittest {
    interface Foo {
        int foo(int, string) @safe pure;
        void bar() @safe pure;
    }

    int doubleFoo(Foo dependency) {
        return 2 * dependency.foo(5, "foobar");
    }

    auto fooMock = mock!Foo;
    fooMock.expect!"foo";
    doubleFoo(fooMock);
}
282 
// Expectations carrying argument values: one matching call and two
// mismatches whose failure messages are pinned exactly.
@("mock interface positive test with params")
@safe pure unittest {
    import unit_threaded.asserts;

    interface Foo {
        int foo(int, string) @safe pure;
        void bar() @safe pure;
    }

    int doubleFoo(Foo dependency) {
        return 2 * dependency.foo(5, "foobar");
    }

    // expected arguments equal the actual ones
    {
        auto matching = mock!Foo;
        matching.expect!"foo"(5, "foobar");
        doubleFoo(matching);
    }

    // the int argument differs
    {
        auto wrongInt = mock!Foo;
        wrongInt.expect!"foo"(6, "foobar");
        doubleFoo(wrongInt);
        assertExceptionMsg(wrongInt.verify,
                           `    source/unit_threaded/mock.d:123 - foo was called with unexpected Tuple!(int, string)(5, "foobar")` ~ "\n" ~
                           `    source/unit_threaded/mock.d:123 -        instead of the expected Tuple!(int, string)(6, "foobar")`);
    }

    // the string argument differs
    {
        auto wrongString = mock!Foo;
        wrongString.expect!"foo"(5, "quux");
        doubleFoo(wrongString);
        assertExceptionMsg(wrongString.verify,
                           `    source/unit_threaded/mock.d:123 - foo was called with unexpected Tuple!(int, string)(5, "foobar")` ~ "\n" ~
                           `    source/unit_threaded/mock.d:123 -        instead of the expected Tuple!(int, string)(5, "quux")`);
    }
}
320 
321 
// Verifying an expectation for a call that never happened must fail.
@("mock interface negative test")
@safe pure unittest {
    import unit_threaded.should;

    interface Foo {
        int foo(int, string) @safe pure;
    }

    auto neverCalled = mock!Foo;
    neverCalled.expect!"foo";
    neverCalled.verify.shouldThrowWithMessage("Expected nth 0 call to foo did not happen");
}
334 
// can't be in the unit test itself
version(unittest) {
    // Class used by the class-mocking tests: one abstract method, one final
    // method and two virtual ones (const and mutable).
    private class Class {
        abstract int foo(int, string) @safe pure;

        final int timesTwo(int i) @safe pure nothrow const {
            return 2 * i;
        }

        int timesThree(int i) @safe pure nothrow const {
            return 3 * i;
        }

        int timesThreeMutable(int i) @safe pure nothrow {
            return 3 * i;
        }
    }
}
343 
// Mocks work for classes too, not only for interfaces.
@("mock class positive test")
@safe pure unittest {

    int doubleFoo(Class dependency) {
        return 2 * dependency.foo(5, "foobar");
    }

    auto classMock = mock!Class;
    classMock.expect!"foo";
    doubleFoo(classMock);
}
355 
356 
// Several expectations are matched against the recorded calls in order.
@("mock interface multiple calls")
@safe pure unittest {
    interface Foo {
        int foo(int, string) @safe pure;
        int bar(int) @safe pure;
    }

    void exercise(Foo dependency) {
        dependency.foo(3, "foo");
        dependency.bar(5);
        dependency.foo(4, "quux");
    }

    auto fooMock = mock!Foo;
    fooMock.expect!"foo"(3, "foo");
    fooMock.expect!"bar"(5);
    fooMock.expect!"foo"(4, "quux");
    exercise(fooMock);
    fooMock.verify;
}
377 
// expectCalled states the expectation after the fact and verifies at once.
@("interface expectCalled")
@safe pure unittest {
    interface Foo {
        int foo(int, string) @safe pure;
        void bar() @safe pure;
    }

    int doubleFoo(Foo dependency) {
        return 2 * dependency.foo(5, "foobar");
    }

    auto fooMock = mock!Foo;
    doubleFoo(fooMock);
    fooMock.expectCalled!"foo"(5, "foobar");
}
393 
// A value queued with returnValue is what the mocked method returns.
@("interface return value")
@safe pure unittest {
    import unit_threaded.should;

    interface Foo {
        int timesN(int i) @safe pure;
    }

    int doubled(Foo dependency) {
        return dependency.timesN(3) * 2;
    }

    auto fooMock = mock!Foo;
    fooMock.returnValue!"timesN"(42);
    immutable result = doubled(fooMock);
    result.shouldEqual(84);
}
411 
// Queued values are returned one per call; once they run out the mocked
// method returns the return type's .init (0 here).
@("interface return values")
@safe pure unittest {
    import unit_threaded.should;

    interface Foo {
        int timesN(int i) @safe pure;
    }

    int doubled(Foo dependency) {
        return dependency.timesN(3) * 2;
    }

    auto fooMock = mock!Foo;
    fooMock.returnValue!"timesN"(42, 12);
    doubled(fooMock).shouldEqual(84);
    doubled(fooMock).shouldEqual(24);
    doubled(fooMock).shouldEqual(0);
}
430 
/**
   Compile-time pairing of a function name with the values a struct mock
   should return for it. Every element of `T` must be a value.
 */
struct ReturnValues(string function_, T...) if(allSatisfy!(isValue, T)) {
    alias funcName = function_;
    alias Values = T;

    /// The values as an array whose element type is that of the first value.
    static auto values() {
        alias Element = typeof(T[0]);
        Element[] collected;
        foreach(item; T)
            collected ~= item;
        return collected;
    }
}
443 
/// True when `T` is an instantiation of `ReturnValues`.
template isReturnValue(alias T) {
    enum isReturnValue = is(T: ReturnValues!U, U...);
}

/// True when `typeof(T)` is valid, i.e. `T` is a value rather than a type.
template isValue(alias T) {
    enum isValue = is(typeof(T));
}
446 
447 
/**
   Version of mockStruct that accepts 0 or more values of the same
   type. Whichever function is called on the mock, these values are
   returned one by one; once they are used up, the `.init` of the first
   value's type is returned. The limitation is that every function
   called on the mock returns that same type.
 */
auto mockStruct(T...)(auto ref T returns) {

    struct Mock {

        MockImpl* _impl;
        alias _impl this;

        static struct MockImpl {

            static if(T.length > 0) {
                alias FirstType = typeof(returns[0]);
                private FirstType[] _returnValues;
            }

            mixin MockImplCommon;

            // any member call lands here: record it, then hand out the next value
            auto opDispatch(string funcName, V...)(auto ref V values) {

                import std.conv: to;
                import std.typecons: tuple;

                calledFuncs ~= funcName;
                calledValues ~= tuple(values).to!string;

                static if(T.length > 0) {

                    if(_returnValues.length == 0) return FirstType.init;
                    auto next = _returnValues[0];
                    _returnValues = _returnValues[1..$];
                    return next;
                }
            }
        }
    }

    Mock result;
    result._impl = new Mock.MockImpl;
    static if(T.length > 0) {
        foreach(value; returns)
            result._impl._returnValues ~= value;
    }

    return result;
}
498 
/**
   Version of mockStruct that accepts a compile-time mapping
   of function name to return values. Each template parameter
   must be an instantiation of `ReturnValues`, e.g.
   `mockStruct!(ReturnValues!("length", 5, 3))`.

   Each call to a mapped function returns that function's next value;
   once the values are used up, the `.init` of their type is returned,
   as the other mocks in this module do.
 */
auto mockStruct(T...)() if(T.length > 0 && allSatisfy!(isReturnValue, T)) {

    struct Mock {
        mixin MockImplCommon;

        // per function name: index of the next value to return
        int[string] _retIndices;

        auto opDispatch(string funcName, V...)(auto ref V values) {

            import std.conv: to;
            import std.typecons: tuple;

            calledFuncs ~= funcName;
            calledValues ~= tuple(values).to!string;

            foreach(retVal; T) {
                static if(retVal.funcName == funcName) {
                    auto configured = retVal.values();
                    immutable index = _retIndices[funcName]++;
                    // don't index past the end once the values run out
                    if(index >= configured.length)
                        return typeof(configured[0]).init;
                    return configured[index];
                }
            }
        }
    }

    Mock mock;

    foreach(retVal; T) {
        mock._retIndices[retVal.funcName] = 0;
    }

    return mock;
}
541 
542 
// A struct mock records a call to any member name.
@("mock struct positive")
@safe pure unittest {
    void callFoobar(T)(T target) {
        target.foobar;
    }
    auto structMock = mockStruct;
    structMock.expect!"foobar";
    callFoobar(structMock);
    structMock.verify;
}
553 
// Verification fails when the expected call never happened.
@("mock struct negative")
@safe pure unittest {
    import unit_threaded.asserts;

    auto structMock = mockStruct;
    structMock.expect!"foobar";
    assertExceptionMsg(structMock.verify,
                       "    source/unit_threaded/mock.d:123 - Expected nth 0 call to foobar did not happen\n");

}
564 
565 
// Expected argument values equal to the actual ones pass verification.
@("mock struct values positive")
@safe pure unittest {
    void callFoobar(T)(T target) {
        target.foobar(2, "quux");
    }

    auto structMock = mockStruct;
    structMock.expect!"foobar"(2, "quux");
    callFoobar(structMock);
    structMock.verify;
}
577 
// Differing argument values produce the two-line mismatch message.
@("mock struct values negative")
@safe pure unittest {
    import unit_threaded.asserts;

    void callFoobar(T)(T target) {
        target.foobar(2, "quux");
    }

    auto structMock = mockStruct;
    structMock.expect!"foobar"(3, "quux");
    callFoobar(structMock);
    assertExceptionMsg(structMock.verify,
                       "    source/unit_threaded/mock.d:123 - foobar was called with unexpected Tuple!(int, string)(2, \"quux\")\n" ~
                       "    source/unit_threaded/mock.d:123 -           instead of the expected Tuple!(int, string)(3, \"quux\")");
}
593 
594 
// Values given to mockStruct are returned one per call, then .init (0).
@("struct return value")
@safe pure unittest {
    import unit_threaded.should;

    int doubled(T)(T target) {
        return target.timesN(3) * 2;
    }

    auto structMock = mockStruct(42, 12);
    doubled(structMock).shouldEqual(84);
    doubled(structMock).shouldEqual(24);
    doubled(structMock).shouldEqual(0);
    structMock.expectCalled!"timesN";
}
609 
// expectCalled on a struct mock checks a call that already happened.
@("struct expectCalled")
@safe pure unittest {
    void callFoobar(T)(T target) {
        target.foobar(2, "quux");
    }

    auto structMock = mockStruct;
    callFoobar(structMock);
    structMock.expectCalled!"foobar"(2, "quux");
}
620 
// With ReturnValues each function name has its own return type and value.
@("mockStruct different return types for different functions")
@safe pure unittest {
    import unit_threaded.should: shouldEqual;
    auto structMock = mockStruct!(ReturnValues!("length", 5),
                                  ReturnValues!("greet", "hello"));
    structMock.length.shouldEqual(5);
    structMock.greet("bar").shouldEqual("hello");
    structMock.expectCalled!"length";
    structMock.expectCalled!"greet"("bar");
}
631 
// Each function name walks through its own list of values, call by call.
@("mockStruct different return types for different functions and multiple return values")
@safe pure unittest {
    import unit_threaded.should: shouldEqual;
    auto structMock = mockStruct!(ReturnValues!("length", 5, 3),
                                  ReturnValues!("greet", "hello", "g'day"));

    // int-returning function
    structMock.length.shouldEqual(5);
    structMock.expectCalled!"length";
    structMock.length.shouldEqual(3);
    structMock.expectCalled!"length";

    // string-returning function
    structMock.greet("bar").shouldEqual("hello");
    structMock.expectCalled!"greet"("bar");
    structMock.greet("quux").shouldEqual("g'day");
    structMock.expectCalled!"greet"("quux");
}
647 
648 
// Compile-only check: a method returning const(ubyte)[] can be mocked.
@("const(ubyte)[] return type]")
@safe pure unittest {
    interface Interface {
        const(ubyte)[] fun();
    }

    auto interfaceMock = mock!Interface;
}
657 
// Compile-only check: a @safe pure nothrow method can be mocked.
@("safe pure nothrow")
@safe pure unittest {
    interface Interface {
        int twice(int i) @safe pure nothrow /*@nogc*/;
    }
    auto interfaceMock = mock!Interface;
}
665 
// Return values are queued per overload index: calling the string
// overload must yield the value queued for index 1.
@("issue 63")
@safe pure unittest {
    import unit_threaded.should;

    interface InterfaceWithOverloads {
        int func(int) @safe pure;
        int func(string) @safe pure;
    }
    auto m = mock!InterfaceWithOverloads;
    m.returnValue!(0, "func")(3); // int overload
    m.returnValue!(1, "func")(7); // string overload
    m.expect!"func"("foo");
    m.func("foo").shouldEqual(7);
    m.verify;
}
682 
683 
/**
   Returns a struct on which calling any member function throws an `E`
   with the message "<function name> was called". `R` is the return type
   those member functions are declared with.
 */
auto throwStruct(E = UnitTestException, R = void)() {

    struct ThrowingMock {

        // every member call lands here and throws
        R opDispatch(string funcName, string file = __FILE__, size_t line = __LINE__, V...)
                    (auto ref V values) {
            enum message = funcName ~ " was called";
            throw new E(message, file, line);
        }
    }

    return ThrowingMock();
}
696 
// Without template arguments the thrown type is UnitTestException.
@("throwStruct default")
@safe pure unittest {
    import unit_threaded.should: shouldThrow;
    auto thrower = throwStruct;
    thrower.foo.shouldThrow!UnitTestException;
    thrower.bar(1, "foo").shouldThrow!UnitTestException;
}
704 
version(testing_unit_threaded) {
    // Exception type used to check throwStruct with a custom exception.
    class FooException: Exception {
        import std.exception: basicExceptionCtors;
        mixin basicExceptionCtors;
    }

    // The first template argument selects the type that gets thrown.
    @("throwStruct custom")
    @safe pure unittest {
        import unit_threaded.should: shouldThrow;

        auto thrower = throwStruct!FooException;
        thrower.foo.shouldThrow!FooException;
        thrower.bar(1, "foo").shouldThrow!FooException;
    }
}
721 
722 
// The second template argument is the declared return type, so the call
// can appear on the right-hand side of an int assignment.
@("throwStruct return value type")
@safe pure unittest {
    import unit_threaded.asserts;
    auto thrower = throwStruct!(UnitTestException, int);
    int result;
    assertExceptionMsg(result = thrower.foo,
                       "    source/unit_threaded/mock.d:123 - foo was called");
    assertExceptionMsg(result = thrower.bar,
                       "    source/unit_threaded/mock.d:123 - bar was called");
}
733 
// A virtual method that has an implementation in the class is overridden
// by the mock as well.
@("issue 68")
@safe pure unittest {
    import unit_threaded.should;

    int callTimesThree(Class dependency) {
        // timesThreeMutable is mocked to return 42, no matter what's passed in
        return dependency.timesThreeMutable(2);
    }

    auto classMock = mock!Class;
    classMock.expect!"timesThreeMutable"(2);
    classMock.returnValue!("timesThreeMutable")(42);
    callTimesThree(classMock).shouldEqual(42);
}
748 
// Overloads with different parameter lists can be mocked, both for an
// interface and for a class, each overload with its own return values.
@("issue69")
unittest {
    import unit_threaded.should;

    static interface InterfaceWithOverloadedFuncs {
        string over();
        string over(string str);
    }

    static class ClassWithOverloadedFuncs {
        string over() { return "oops"; }
        string over(string str) { return "oopsie"; }
    }

    // interface
    auto interfaceMock = mock!InterfaceWithOverloadedFuncs;
    interfaceMock.returnValue!(0, "over")("bar");
    interfaceMock.returnValue!(1, "over")("baz");
    interfaceMock.over.shouldEqual("bar");
    interfaceMock.over("zing").shouldEqual("baz");

    // class
    auto classMock = mock!ClassWithOverloadedFuncs;
    classMock.returnValue!(0, "over")("bar");
    classMock.returnValue!(1, "over")("baz");
    classMock.over.shouldEqual("bar");
    classMock.over("zing").shouldEqual("baz");
}