/**
   Support the automatic implementation of test doubles via programmable mocks.
 */
module unit_threaded.mock;

import unit_threaded.from;

/// Forwards its alias argument unchanged, so that `__traits` results can be aliased.
alias Identity(alias T) = T;

/// True when `__traits(getMember, T, member)` does not compile from this module.
private enum isPrivate(T, string member) = !__traits(compiles, __traits(getMember, T, member));


/**
   Generates, at compile time, the source code that overrides the virtual
   methods of `T` for use in a mock class.

   For every accessible virtual member, and for each of its overloads that is
   neither `const` nor `immutable`, the generated code contains:
   $(UL
     $(LI aliases `<name>_<i>_parameters` and `<name>_<i>_returnType`)
     $(LI for non-void overloads, a `<name>_<i>_returnValues` array holding
          the values to hand back, consumed front to back)
     $(LI an `override` that appends the method name to `calledFuncs` and the
          stringified argument tuple to `calledValues`, then returns the next
          programmed value or `.init` when none is left)
   )

   The string is meant to be mixed into a class deriving from `T` that also
   mixes in `MockImplCommon` and has `to`, `tuple`, `Parameters` and
   `ReturnType` in scope.

   Returns: the generated code, or `null` when not evaluated at compile time.
 */
string implMixinStr(T)() {
    import std.array: join;
    import std.format : format;
    import std.traits: functionAttributes, FunctionAttribute, Parameters, ReturnType, arity;
    import std.conv: text;

    if(!__ctfe) return null;

    string[] lines;

    // Source-code expression that names the i-th overload of memberName in T.
    string getOverload(in string memberName, in int i) {
        return `Identity!(__traits(getOverloads, T, "%s")[%s])`
            .format(memberName, i);
    }

    foreach(memberName; __traits(allMembers, T)) {

        static if(!isPrivate!(T, memberName)) {

            alias member = Identity!(__traits(getMember, T, memberName));

            static if(__traits(isVirtualMethod, member)) {
                foreach(i, overload; __traits(getOverloads, T, memberName)) {

                    // const and immutable methods cannot be mocked since the mock
                    // has to record the call, i.e. alter its own state.
                    // Every trait below is taken from `overload`, not from `member`:
                    // `member` only describes the first overload, which yields wrong
                    // code for overloads with a different signature or attributes.
                    static if(!(functionAttributes!overload & FunctionAttribute.const_) &&
                              !(functionAttributes!overload & FunctionAttribute.immutable_)) {

                        enum overloadName = text(memberName, "_", i);

                        enum overloadString = getOverload(memberName, i);
                        lines ~= "private alias %s_parameters = Parameters!(%s);".format(overloadName, overloadString);
                        lines ~= "private alias %s_returnType = ReturnType!(%s);".format(overloadName, overloadString);

                        static if(functionAttributes!overload & FunctionAttribute.nothrow_)
                            enum tryIndent = " ";
                        else
                            enum tryIndent = "";

                        static if(is(ReturnType!overload == void))
                            enum returnDefault = "";
                        else {
                            enum varName = overloadName ~ `_returnValues`;
                            lines ~= `%s_returnType[] %s;`.format(overloadName, varName);
                            lines ~= "";
                            // Pop the next programmed value, or fall back to .init
                            enum returnDefault = [` if(` ~ varName ~ `.length > 0) {`,
                                                  ` auto ret = ` ~ varName ~ `[0];`,
                                                  ` ` ~ varName ~ ` = ` ~ varName ~ `[1..$];`,
                                                  ` return ret;`,
                                                  ` } else`,
                                                  ` return %s_returnType.init;`.format(overloadName)];
                        }

                        lines ~= `override ` ~ overloadName ~ "_returnType " ~ memberName ~
                            typeAndArgsParens!(Parameters!overload)(overloadName) ~ " " ~
                            functionAttributesString!overload ~ ` {`;

                        // Recording the call may throw (array append, to!string), so
                        // nothrow overrides wrap the bookkeeping in try/catch.
                        static if(functionAttributes!overload & FunctionAttribute.nothrow_)
                            lines ~= "try {";

                        lines ~= tryIndent ~ ` calledFuncs ~= "` ~ memberName ~ `";`;
                        lines ~= tryIndent ~ ` calledValues ~= tuple` ~ argNamesParens(arity!overload) ~ `.to!string;`;

                        static if(functionAttributes!overload & FunctionAttribute.nothrow_)
                            lines ~= " } catch(Exception) {}";

                        lines ~= returnDefault;

                        lines ~= `}`;
                        lines ~= "";
                    }
                }
            }
        }
    }

    return lines.join("\n");
}

/// Returns "(arg0, arg1, ...)" with N argument names; `null` outside of CTFE.
private string argNamesParens(int N) @safe pure {
    if(!__ctfe) return null;
    return "(" ~ argNames(N) ~ ")";
}

/// Returns "arg0, arg1, ..." with N argument names; `null` outside of CTFE.
private string argNames(int N) @safe pure {
    import std.range;
    import std.algorithm;
    import std.conv;

    if(!__ctfe) return null;
    return iota(N).map!(a => "arg" ~ a.to!string).join(", ");
}

/**
   Returns a parenthesised parameter list with one entry per type in `T`, each
   of the form `<prefix>_parameters[i] arg<i>`. Only the number of types in
   `T` is used; the types themselves are referenced through the
   `<prefix>_parameters` alias in the generated code. `null` outside of CTFE.
 */
private string typeAndArgsParens(T...)(string prefix) {
    import std.array;
    import std.conv;
    import std.format : format;

    if(!__ctfe) return null;

    string[] parts;

    foreach(i, t; T)
        parts ~= "%s_parameters[%s] arg%s".format(prefix, i, i);
    return "(" ~ parts.join(", ") ~ ")";
}

/**
   Returns the space-separated attributes of function `F` (pure, nothrow,
   @trusted, @safe, @nogc, @system, shared) as source code. `const` and
   `immutable` are deliberately left out. `null` outside of CTFE.
 */
private string functionAttributesString(alias F)() {
    import std.traits: functionAttributes, FunctionAttribute;
    import std.array: join;

    if(!__ctfe) return null;

    string[] parts;

    const attrs = functionAttributes!F;

    if(attrs & FunctionAttribute.pure_) parts ~= "pure";
    if(attrs & FunctionAttribute.nothrow_) parts ~= "nothrow";
    if(attrs & FunctionAttribute.trusted) parts ~= "@trusted";
    if(attrs & FunctionAttribute.safe) parts ~= "@safe";
    if(attrs & FunctionAttribute.nogc) parts ~= "@nogc";
    if(attrs & FunctionAttribute.system) parts ~= "@system";
    // const and immutable can't be done since the mock needs
    // to alter state
    // if(attrs & FunctionAttribute.const_) parts ~= "const";
    // if(attrs & FunctionAttribute.immutable_) parts ~= "immutable";
    if(attrs & FunctionAttribute.shared_) parts ~= "shared";

    return parts.join(" ");
}

/**
   State and verification logic shared by the mock implementations:
   the recorded calls, the expected calls and the means to compare them.
 */
mixin template MockImplCommon() {
    bool _verified;
    string[] expectedFuncs;
    string[] calledFuncs;
    string[] expectedValues;
    string[] calledValues;

    /**
       Registers an expected call to `funcName`. When `values` are given they
       are stored as the stringified tuple of arguments; otherwise an empty
       string is stored, which `verify` treats as "any arguments".
     */
    void expect(string funcName, V...)(auto ref V values) {
        import std.conv: to;
        import std.typecons: tuple;

        expectedFuncs ~= funcName;
        static if(V.length > 0)
            expectedValues ~= tuple(values).to!string;
        else
            expectedValues ~= "";
    }

    /**
       Registers the expectation and verifies immediately. `_verified` is
       reset afterwards so that the mock can be verified again later.
     */
    void expectCalled(string func, string file = __FILE__, size_t line = __LINE__, V...)(auto ref V values) {
        expect!func(values);
        verify(file, line);
        _verified = false;
    }

    /**
       Compares the expected calls with the recorded ones, position by
       position, and fails on the first mismatch in function name or (when
       expected values were given) in arguments. Fails as well when the mock
       has already been verified.
     */
    void verify(string file = __FILE__, size_t line = __LINE__) @safe pure {
        import std.range: repeat, take, join;
        import std.conv: to;
        import unit_threaded.should: fail, UnitTestException;

        if(_verified)
            fail("Mock already _verified", file, line);

        _verified = true;

        foreach(i; 0 .. expectedFuncs.length) {

            if(i >= calledFuncs.length)
                fail("Expected nth " ~ i.to!string ~ " call to " ~ expectedFuncs[i] ~ " did not happen", file, line);

            if(expectedFuncs[i] != calledFuncs[i])
                fail("Expected nth " ~ i.to!string ~ " call to " ~ expectedFuncs[i] ~ " but got " ~ calledFuncs[i] ~
                     " instead",
                     file, line);

            // An empty expectation means the arguments were not specified
            if(expectedValues[i] != calledValues[i] && expectedValues[i] != "")
                throw new UnitTestException([expectedFuncs[i] ~ " was called with unexpected " ~ calledValues[i],
                                             " ".repeat.take(expectedFuncs[i].length + 4).join ~
                                             "instead of the expected " ~ expectedValues[i]] ,
                                            file, line);
        }
    }
}

/// True when the aliased symbol or value has type `string`.
private enum isString(alias T) = is(typeof(T) == string);

/**
   A mock object that conforms to an interface/class.
 */
struct Mock(T) {

    MockAbstract _impl;
    alias _impl this;

    /// The generated class that derives from `T` and records calls.
    class MockAbstract: T {
        import std.conv: to;
        import std.traits: Parameters, ReturnType;
        import std.typecons: tuple;

        //pragma(msg, "\nimplMixinStr for ", T, "\n\n", implMixinStr!T, "\n\n");
        mixin(implMixinStr!T);
        mixin MockImplCommon;
    }

    ///
    this(int/* force constructor*/) {
        _impl = new MockAbstract;
    }

    /// Verifies the expectations on destruction unless already verified.
    ~this() pure @safe {
        // A default-initialised Mock (Mock!T.init) has no implementation
        // object; there is nothing to verify and `_verified` must not be
        // dereferenced through a null `_impl`.
        if(_impl !is null && !_verified) verify;
    }

    /// Set the returnValue of a function to certain values.
    void returnValue(string funcName, V...)(V values) {
        assertFunctionIsVirtual!funcName;
        return returnValue!(0, funcName)(values);
    }

    /**
       This version takes overloads into account. i is the overload
       index. e.g.:
       ---------
       interface Interface { void foo(int); void foo(string); }
       auto m = mock!Interface;
       m.returnValue!(0, "foo"); // int overload
       m.returnValue!(1, "foo"); // string overload
       ---------
     */
    void returnValue(int i, string funcName, V...)(V values) {
        assertFunctionIsVirtual!funcName;
        import std.conv: text;
        // Name of the array generated by implMixinStr for this overload
        enum varName = funcName ~ text(`_`, i, `_returnValues`);
        foreach(v; values)
            mixin(varName ~ ` ~= v;`);
    }

    /// Fails compilation when `funcName` is not a virtual method of `T`.
    private static void assertFunctionIsVirtual(string funcName)() {
        alias member = Identity!(__traits(getMember, T, funcName));

        static assert(__traits(isVirtualMethod, member),
                      "Cannot use returnValue on '" ~ funcName ~ "'");
    }
}

/// Returns source code importing `module_` and every module in `Modules`; `null` outside of CTFE.
private string importsString(string module_, string[] Modules...) {
    if(!__ctfe) return null;

    auto ret = `import ` ~ module_ ~ ";\n";
    foreach(extraModule; Modules) {
        ret ~= `import ` ~ extraModule ~ ";\n";
    }
    return ret;
}

/// Helper function for creating a Mock object.
auto mock(T)() {
    return Mock!T(0);
}

///
@("mock interface positive test no params")
@safe pure unittest {
    interface Foo {
        int foo(int, string) @safe pure;
        void bar() @safe pure;
    }

    int fun(Foo f) {
        return 2 * f.foo(5, "foobar");
    }

    auto m = mock!Foo;
    m.expect!"foo";
    fun(m);
}


///
@("mock interface positive test with params")
@safe pure unittest {
    import unit_threaded.asserts;

    interface Foo {
        int foo(int, string) @safe pure;
        void bar() @safe pure;
    }

    int fun(Foo f) {
        return 2 * f.foo(5, "foobar");
    }

    auto m = mock!Foo;
    m.expect!"foo"(5, "foobar");
    fun(m);
}


///
@("interface expectCalled")
@safe pure unittest {
    interface Foo {
        int foo(int, string) @safe pure;
        void bar() @safe pure;
    }

    int fun(Foo f) {
        return 2 * f.foo(5, "foobar");
    }

    auto m = mock!Foo;
    fun(m);
    m.expectCalled!"foo"(5, "foobar");
}

///
@("interface return value")
@safe pure unittest {

    interface Foo {
        int timesN(int i) @safe pure;
    }

    int fun(Foo f) {
        return f.timesN(3) * 2;
    }

    auto m = mock!Foo;
    m.returnValue!"timesN"(42);
    immutable res = fun(m);
    assert(res == 84);
}

///
@("interface return values")
@safe pure unittest {

    interface Foo {
        int timesN(int i) @safe pure;
    }

    int fun(Foo f) {
        return f.timesN(3) * 2;
    }

    auto m = mock!Foo;
    m.returnValue!"timesN"(42, 12);
    assert(fun(m) == 84);
    assert(fun(m) == 24);
    assert(fun(m) == 0);
}

/**
   Compile-time association of a function name with the values that a
   struct mock should return for it, in order. The element type of the
   returned array is the type of the first value.
 */
struct ReturnValues(string function_, T...) if(from!"std.meta".allSatisfy!(isValue, T)) {
    alias funcName = function_;
    alias Values = T;

    /// The values as a dynamic array, built anew on every call.
    static auto values() {
        typeof(T[0])[] ret;
        foreach(val; T) {
            ret ~= val;
        }
        return ret;
    }
}

/// True when `T` is an instantiation of `ReturnValues`.
enum isReturnValue(alias T) = is(T: ReturnValues!U, U...);
/// True when `T` has a type, i.e. `typeof(T)` is valid.
enum isValue(alias T) = is(typeof(T));


/**
   Version of mockStruct that accepts 0 or more values of the same
   type. Whatever function is called on it, these values will
   be returned one by one. The limitation is that if more than one
   function is called on the mock, they all return the same type.
   Once the values run out, the `.init` value of that type is returned.
 */
auto mockStruct(T...)(auto ref T returns) {

    struct Mock {

        MockImpl* _impl;
        alias _impl this;

        static struct MockImpl {

            static if(T.length > 0) {
                alias FirstType = typeof(returns[0]);
                private FirstType[] _returnValues;
            }

            mixin MockImplCommon;

            /// Records the call and, when return values were given, returns the next one.
            auto opDispatch(string funcName, V...)(auto ref V values) {

                import std.conv: to;
                import std.typecons: tuple;

                calledFuncs ~= funcName;
                calledValues ~= tuple(values).to!string;

                static if(T.length > 0) {

                    if(_returnValues.length == 0) return typeof(_returnValues[0]).init;
                    auto ret = _returnValues[0];
                    _returnValues = _returnValues[1..$];
                    return ret;
                }
            }
        }
    }

    Mock m;
    m._impl = new Mock.MockImpl;
    static if(T.length > 0) {
        foreach(r; returns)
            m._impl._returnValues ~= r;
    }

    return m;
}

/**
   Version of mockStruct that accepts a compile-time mapping
   of function name to return values. Each template parameter
   must be a value of type `ReturnValues`.
   Once the values for a function run out, the `.init` value of
   their type is returned, as the other mocks in this module do.
 */
auto mockStruct(T...)() if(T.length > 0 && from!"std.meta".allSatisfy!(isReturnValue, T)) {

    struct Mock {
        mixin MockImplCommon;

        // Index of the next value to return, keyed by function name
        int[string] _retIndices;

        /**
           Records the call. When `funcName` has an entry in `T`, returns its
           next value, or `.init` after the last one has been handed out.
         */
        auto opDispatch(string funcName, V...)(auto ref V values) {

            import std.conv: to;
            import std.typecons: tuple;

            calledFuncs ~= funcName;
            calledValues ~= tuple(values).to!string;

            foreach(retVal; T) {
                static if(retVal.funcName == funcName) {
                    auto vals = retVal.values;
                    immutable idx = _retIndices[funcName]++;
                    // Don't index past the end: fall back to the default
                    // value instead of failing with a RangeError.
                    if(idx < vals.length) return vals[idx];
                    return typeof(vals[0]).init;
                }
            }
        }

        // NOTE(review): looks like leftover debugging code (hard-coded "greet"
        // key, always reads from T[0]); kept so the struct's members are
        // unchanged - confirm that nothing uses it and remove.
        auto lefoofoo() {
            return T[0].values[_retIndices["greet"]++];
        }

    }

    Mock mock;

    // Every mapped function starts at its first value
    foreach(retVal; T) {
        mock._retIndices[retVal.funcName] = 0;
    }

    return mock;
}

///
@("mock struct positive")
@safe pure unittest {
    void fun(T)(T t) {
        t.foobar;
    }
    auto m = mockStruct;
    m.expect!"foobar";
    fun(m);
    m.verify;
}


///
@("mock struct values positive")
@safe pure unittest {
    void fun(T)(T t) {
        t.foobar(2, "quux");
    }

    auto m = mockStruct;
    m.expect!"foobar"(2, "quux");
    fun(m);
    m.verify;
}


///
@("struct return value")
@safe pure unittest {

    int fun(T)(T f) {
        return f.timesN(3) * 2;
    }

    auto m = mockStruct(42, 12);
    assert(fun(m) == 84);
    assert(fun(m) == 24);
    assert(fun(m) == 0);
    m.expectCalled!"timesN";
}

///
@("struct expectCalled")
@safe pure unittest {
    void fun(T)(T t) {
        t.foobar(2, "quux");
    }

    auto m = mockStruct;
    fun(m);
    m.expectCalled!"foobar"(2, "quux");
}

///
@("mockStruct different return types for different functions")
@safe pure unittest {
    auto m = mockStruct!(ReturnValues!("length", 5),
                         ReturnValues!("greet", "hello"));
    assert(m.length == 5);
    assert(m.greet("bar") == "hello");
    m.expectCalled!"length";
    m.expectCalled!"greet"("bar");
}

///
@("mockStruct different return types for different functions and multiple return values")
@safe pure unittest {
    auto m = mockStruct!(ReturnValues!("length", 5, 3),
                         ReturnValues!("greet", "hello", "g'day"));
    assert(m.length == 5);
    m.expectCalled!"length";
    assert(m.length == 3);
    m.expectCalled!"length";

    assert(m.greet("bar") == "hello");
    m.expectCalled!"greet"("bar");
    assert(m.greet("quux") == "g'day");
    m.expectCalled!"greet"("quux");
}


/**
   A mock struct that always throws.
   Any member function called on it throws an exception of type `E`
   whose message is the called function's name followed by " was called".
   `R` is the declared return type of every such call.
 */
auto throwStruct(E = from!"unit_threaded.should".UnitTestException, R = void)() {

    struct Mock {

        R opDispatch(string funcName, string file = __FILE__, size_t line = __LINE__, V...)
                    (auto ref V values) {
            throw new E(funcName ~ " was called", file, line);
        }
    }

    return Mock();
}

///
@("throwStruct default")
@safe pure unittest {
    import std.exception: assertThrown;
    import unit_threaded.should: UnitTestException;
    auto m = throwStruct;
    assertThrown!UnitTestException(m.foo);
    assertThrown!UnitTestException(m.bar(1, "foo"));
}